diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 434 |
1 files changed, 434 insertions, 0 deletions
diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..6f18e9e --- /dev/null +++ b/src/main.c @@ -0,0 +1,434 @@ +#include <math.h> +#include <ctype.h> +#include <stdio.h> +#include <string.h> +#include <stdlib.h> +#include <unistd.h> +#include <threads.h> + +#include <ejit/ejit.h> + +/* input parameters */ +#define VAR_X EJIT_FPR(0) +#define VAR_Y EJIT_FPR(1) + +/* 'rest' of parameters, use this as a kind of base */ +#define VAR_0 EJIT_FPR(2) + +enum op { + OP_UNKNOWN, + OP_VAR_X, + OP_VAR_Y, + OP_CONST, + OP_ADD, + OP_SUB, + OP_MUL, + OP_MAX, + OP_MIN, + OP_NEG, + OP_SQUARE, + OP_SQRT +}; + +static float escape_sqrt(size_t argc, const struct ejit_arg args[argc]) +{ + assert(argc == 1); + assert(args[0].type == EJIT_FLOAT); + return sqrtf(args[0].f); +} + +static const char *skip_whitespace(const char *line) +{ + while (*line && isspace(*line)) + line++; + + return line; +} + +static const char *next_whitespace(const char *line) +{ + while (*line && !isspace(*line)) + line++; + + return line; +} + +static const char *parse_var(const char *line, size_t *var) +{ + assert(var); + + line = skip_whitespace(line); + assert(line[0] == '_'); + + char *end = NULL; + *var = strtoull(line + 1, &end, 16) + VAR_0.f; + assert(end > line + 1); + + return end; +} + +static const char *parse_op(const char *line, enum op *op) +{ + assert(op); + + const char *start = skip_whitespace(line); + const char *end = next_whitespace(start); + + int len = end - start; + + /* this is terrible but whatever */ + if (len == 6) { + if (strncmp(start, "square", len) == 0) + *op = OP_SQUARE; + else + fprintf(stderr, "unknown op: %.*s\n", len, start); + } + else if (len == 5) { + if (strncmp(start, "var-x", len) == 0) + *op = OP_VAR_X; + + else if (strncmp(start, "var-y", len) == 0) + *op = OP_VAR_Y; + + else if (strncmp(start, "const", len) == 0) + *op = OP_CONST; + else + fprintf(stderr, "unknown op: %.*s\n", len, start); + } + else if (len == 4) { + if (strncmp(start, "sqrt", len) == 0) + *op = OP_SQRT; + else + fprintf(stderr, "unknown op: %.*s\n", len, start); + } + else if (len == 3) { + if (strncmp(start, "add", len) == 0) + *op = OP_ADD; + + else if (strncmp(start, "sub", len) == 0) + *op = OP_SUB; + + else if (strncmp(start, "mul", len) == 0) + *op = OP_MUL; + + else if (strncmp(start, "max", len) == 0) + *op = OP_MAX; + + else if (strncmp(start, "min", len) == 0) + *op = OP_MIN; + + else if (strncmp(start, "neg", len) == 0) + *op = OP_NEG; + else + fprintf(stderr, "unknown op: %.*s\n", len, start); + } + else { + fprintf(stderr, "unknown op: %.*s\n", len, start); + } + + return end; +} + +static size_t compile_line(struct ejit_func *kernel, const char *line) +{ + /* not a particularly good parser but dusting off flex+bison felt like + * overkill */ + line = skip_whitespace(line); + + /* hopefully the last line is not a comment :) */ + if (line[0] == '#') + return 0; + + size_t out = 0; + line = parse_var(line, &out); + + enum op op = OP_UNKNOWN; + line = parse_op(line, &op); + switch (op) { + case OP_VAR_X: { + ejit_movr_f(kernel, EJIT_FPR(out), VAR_X); + break; + } + + case OP_VAR_Y: { + ejit_movr_f(kernel, EJIT_FPR(out), VAR_Y); + break; + } + + case OP_CONST: { + char *end = NULL; + float c = strtof(line, &end); + ejit_movi_f(kernel, EJIT_FPR(out), c); + line = end; + break; + } + + case OP_ADD: { + size_t arg0 = 0, arg1 = 0; + line = parse_var(line, &arg0); + line = parse_var(line, &arg1); + ejit_addr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); + break; + } + + case OP_SUB: { + size_t arg0 = 0, arg1 = 0; + line = parse_var(line, &arg0); + line = parse_var(line, &arg1); + ejit_subr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); + break; + } + + case OP_MUL: { + size_t arg0 = 0, arg1 = 0; + line = parse_var(line, &arg0); + line = parse_var(line, &arg1); + ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); + break; + } + + case OP_MAX: { + size_t arg0 = 0, arg1 = 0; + line = parse_var(line, &arg0); + line = parse_var(line, &arg1); + + /* no built-in fmax/fmin so use a plain old branch */ + struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1)); + /* arg0 < arg1 */ + ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1)); + struct ejit_reloc j = ejit_jmp(kernel); + + /* arg0 >= arg1 */ + ejit_patch(kernel, r, ejit_label(kernel)); + ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0)); + + ejit_patch(kernel, j, ejit_label(kernel)); + break; + } + + case OP_MIN: { + size_t arg0 = 0, arg1 = 0; + line = parse_var(line, &arg0); + line = parse_var(line, &arg1); + + struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1)); + /* arg0 < arg1 */ + ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0)); + struct ejit_reloc j = ejit_jmp(kernel); + + /* arg0 >= arg1 */ + ejit_patch(kernel, r, ejit_label(kernel)); + ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1)); + + ejit_patch(kernel, j, ejit_label(kernel)); + break; + } + + case OP_NEG: { + size_t arg = 0; + line = parse_var(line, &arg); + ejit_negr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg)); + break; + } + + case OP_SQUARE: { + size_t arg = 0; + line = parse_var(line, &arg); + ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg), EJIT_FPR(arg)); + break; + } + + case OP_SQRT: { + size_t arg = 0; + line = parse_var(line, &arg); + /* unfortunately no built-in operation for sqrt so we have to do + * an expensive call */ + struct ejit_operand op = EJIT_OPERAND_FPR(arg, EJIT_FLOAT); + ejit_escapei_f(kernel, escape_sqrt, 1, &op); + ejit_retval_f(kernel, EJIT_FPR(out)); + break; + } + + case OP_UNKNOWN: + abort(); + break; + } + + assert(line[0] == '\n'); + return out; +} + +static struct ejit_func *compile(const char *file, bool do_jit) +{ + FILE *f = fopen(file, "rb"); + assert(f); + + struct ejit_operand operands[2] = { + EJIT_OPERAND_FPR(VAR_X.f, EJIT_FLOAT), + EJIT_OPERAND_FPR(VAR_Y.f, EJIT_FLOAT) + }; + struct ejit_func *kernel = ejit_create_func(EJIT_FLOAT, 2, operands); + assert(kernel); + + size_t size = 0; + char *line = NULL; + size_t out = 0; + while (getline(&line, &size, f) != -1) + out = compile_line(kernel, line); + + ejit_retr_f(kernel, EJIT_FPR(out)); + + /* normally you'd just use ejit_compile_func(kernel), but since I want + * to explicitly compare JIT vs bytecode performance, I use this uglier, + * 'internal' compile function */ + ejit_select_compile_func(kernel, 0, out + 1, + /* we don't use any 64bit values */ + false, + /* compare jit vs bytecode */ + do_jit, + /* if we do jit, we want to mark pages read-only */ + true); + free(line); + fclose(f); + return kernel; +} + +struct work_split { + struct ejit_func *kernel; + + int8_t *data; + + /* which rows to process */ + int start; + int end; + + /* width of a column */ + int res; +}; + +static int do_work(void *arg) +{ + const struct work_split *split = arg; + + size_t idx = 0; + for (int x = split->start; x < split->end; ++x) + for (int y = 0 ; y < split->res; ++y) + { + /* convert [0, res] to [-1, 1] */ + float sx = -(x - split->res / 2) / ((float)split->res / 2); + float sy = (y - split->res / 2) / ((float)split->res / 2); + struct ejit_arg args[2] = { + EJIT_ARG(sy, float), + EJIT_ARG(sx, float) + }; + + /* single byte at a time */ + float r = ejit_run_func_f(split->kernel, 2, args); + split->data[idx++] = (r < 0) ? 255 : 0; + } + + return 0; +} + +static void usage() +{ + printf("usage: prospero [-j threads] [-r resolution] [-c]\n"); + printf(" -j how many threads to run\n"); + printf(" -r length of one side\n"); + printf(" -c enable JIT\n"); +} + +int main(int argc, char *argv[argc]) +{ + + int j = 1; + int res = 1024; + bool jit = false; + + int option = 0; + while ((option = getopt(argc, argv, "hj:r:c")) != -1) { + switch (option) { + case 'h': + usage(); + exit(EXIT_SUCCESS); + + case 'j': + j = atoi(optarg); + break; + + case 'r': + res = atoi(optarg); + break; + + case 'c': + jit = true; + break; + + default: + usage(); + exit(EXIT_FAILURE); + } + } + + if (optind != argc) { + usage(); + exit(EXIT_FAILURE); + } + + struct ejit_func *kernel = compile("prospero.vm", jit); + assert(kernel); + + thrd_t *workers = malloc(sizeof(thrd_t) * j); + struct work_split *splits = malloc(sizeof(struct work_split) * j); + + /* calculate how much work each thread should do */ + int work_len = res / j; + for (int i = 0; i < j; ++i) { + splits[i].kernel = kernel; + splits[i].res = res; + splits[i].start = i * work_len; + + /* give last split the leftovers */ + if (i == j - 1) + splits[i].end = res; + else + splits[i].end = (i + 1) * work_len; + + size_t len = splits[i].end - splits[i].start; + splits[i].data = malloc(res * len * sizeof(int8_t)); + } + + /* start threads */ + for (int i = 0; i < j; ++i) { + int r = thrd_create(&workers[i], do_work, &splits[i]); + assert(r == thrd_success); + } + + + /* collect threads */ + for (int i = 0; i < j; ++i) { + int res = 0; + thrd_join(workers[i], &res); + assert(res == 0); + } + + ejit_destroy_func(kernel); + + FILE *f = fopen("out.ppm", "wb+"); + assert(f); + + /* ppm header */ + fprintf(f, "P5\n%d %d\n255\n", res, res); + + /* image data */ + for (int i = 0; i < j; ++i) { + size_t len = splits[i].end - splits[i].start; + fwrite(splits[i].data, sizeof(int8_t), res * len, f); + free(splits[i].data); + } + + free(workers); + free(splits); + fclose(f); + return 0; +} |