#include #include #include #include #include #include #include #include /* input parameters */ #define VAR_X EJIT_FPR(0) #define VAR_Y EJIT_FPR(1) /* 'rest' of parameters, use this as a kind of base */ #define VAR_0 EJIT_FPR(2) enum op { OP_UNKNOWN, OP_VAR_X, OP_VAR_Y, OP_CONST, OP_ADD, OP_SUB, OP_MUL, OP_MAX, OP_MIN, OP_NEG, OP_SQUARE, OP_SQRT }; static float escape_sqrt(size_t argc, const struct ejit_arg args[argc]) { assert(argc == 1); assert(args[0].type == EJIT_FLOAT); return sqrtf(args[0].f); } static const char *skip_whitespace(const char *line) { while (*line && isspace(*line)) line++; return line; } static const char *next_whitespace(const char *line) { while (*line && !isspace(*line)) line++; return line; } static const char *parse_var(const char *line, size_t *var) { assert(var); line = skip_whitespace(line); assert(line[0] == '_'); char *end = NULL; *var = strtoull(line + 1, &end, 16) + VAR_0.f; assert(end > line + 1); return end; } static const char *parse_op(const char *line, enum op *op) { assert(op); const char *start = skip_whitespace(line); const char *end = next_whitespace(start); int len = end - start; /* this is terrible but whatever */ if (len == 6) { if (strncmp(start, "square", len) == 0) *op = OP_SQUARE; else fprintf(stderr, "unknown op: %.*s\n", len, start); } else if (len == 5) { if (strncmp(start, "var-x", len) == 0) *op = OP_VAR_X; else if (strncmp(start, "var-y", len) == 0) *op = OP_VAR_Y; else if (strncmp(start, "const", len) == 0) *op = OP_CONST; else fprintf(stderr, "unknown op: %.*s\n", len, start); } else if (len == 4) { if (strncmp(start, "sqrt", len) == 0) *op = OP_SQRT; else fprintf(stderr, "unknown op: %.*s\n", len, start); } else if (len == 3) { if (strncmp(start, "add", len) == 0) *op = OP_ADD; else if (strncmp(start, "sub", len) == 0) *op = OP_SUB; else if (strncmp(start, "mul", len) == 0) *op = OP_MUL; else if (strncmp(start, "max", len) == 0) *op = OP_MAX; else if (strncmp(start, "min", len) == 0) *op = OP_MIN; else if (strncmp(start, "neg", len) == 0) *op = OP_NEG; else fprintf(stderr, "unknown op: %.*s\n", len, start); } else { fprintf(stderr, "unknown op: %.*s\n", len, start); } return end; } static size_t compile_line(struct ejit_func *kernel, const char *line) { /* not a particularly good parser but dusting off flex+bison felt like * overkill */ line = skip_whitespace(line); /* hopefully the last line is not a comment :) */ if (line[0] == '#') return 0; size_t out = 0; line = parse_var(line, &out); enum op op = OP_UNKNOWN; line = parse_op(line, &op); switch (op) { case OP_VAR_X: { ejit_movr_f(kernel, EJIT_FPR(out), VAR_X); break; } case OP_VAR_Y: { ejit_movr_f(kernel, EJIT_FPR(out), VAR_Y); break; } case OP_CONST: { char *end = NULL; float c = strtof(line, &end); ejit_movi_f(kernel, EJIT_FPR(out), c); line = end; break; } case OP_ADD: { size_t arg0 = 0, arg1 = 0; line = parse_var(line, &arg0); line = parse_var(line, &arg1); ejit_addr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); break; } case OP_SUB: { size_t arg0 = 0, arg1 = 0; line = parse_var(line, &arg0); line = parse_var(line, &arg1); ejit_subr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); break; } case OP_MUL: { size_t arg0 = 0, arg1 = 0; line = parse_var(line, &arg0); line = parse_var(line, &arg1); ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1)); break; } case OP_MAX: { size_t arg0 = 0, arg1 = 0; line = parse_var(line, &arg0); line = parse_var(line, &arg1); /* no built-in fmax/fmin so use a plain old branch */ struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1)); /* arg0 < arg1 */ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1)); struct ejit_reloc j = ejit_jmp(kernel); /* arg0 >= arg1 */ ejit_patch(kernel, r, ejit_label(kernel)); ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0)); ejit_patch(kernel, j, ejit_label(kernel)); break; } case OP_MIN: { size_t arg0 = 0, arg1 = 0; line = parse_var(line, &arg0); line = parse_var(line, &arg1); struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1)); /* arg0 < arg1 */ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0)); struct ejit_reloc j = ejit_jmp(kernel); /* arg0 >= arg1 */ ejit_patch(kernel, r, ejit_label(kernel)); ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1)); ejit_patch(kernel, j, ejit_label(kernel)); break; } case OP_NEG: { size_t arg = 0; line = parse_var(line, &arg); ejit_negr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg)); break; } case OP_SQUARE: { size_t arg = 0; line = parse_var(line, &arg); ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg), EJIT_FPR(arg)); break; } case OP_SQRT: { size_t arg = 0; line = parse_var(line, &arg); /* unfortunately no built-in operation for sqrt so we have to do * an expensive call */ struct ejit_operand op = EJIT_OPERAND_FPR(arg, EJIT_FLOAT); ejit_escapei_f(kernel, escape_sqrt, 1, &op); ejit_retval_f(kernel, EJIT_FPR(out)); break; } case OP_UNKNOWN: abort(); break; } assert(line[0] == '\n'); return out; } static struct ejit_func *compile(const char *file, bool do_jit) { FILE *f = fopen(file, "rb"); assert(f); struct ejit_operand operands[2] = { EJIT_OPERAND_FPR(VAR_X.f, EJIT_FLOAT), EJIT_OPERAND_FPR(VAR_Y.f, EJIT_FLOAT) }; struct ejit_func *kernel = ejit_create_func(EJIT_FLOAT, 2, operands); assert(kernel); size_t size = 0; char *line = NULL; size_t out = 0; while (getline(&line, &size, f) != -1) out = compile_line(kernel, line); ejit_retr_f(kernel, EJIT_FPR(out)); /* normally you'd just use ejit_compile_func(kernel), but since I want * to explicitly compare JIT vs bytecode performance, I use this uglier, * 'internal' compile function */ ejit_select_compile_func(kernel, 0, out + 1, /* we don't use any 64bit values */ false, /* compare jit vs bytecode */ do_jit, /* if we do jit, we want to mark pages read-only */ true); free(line); fclose(f); return kernel; } struct work_split { struct ejit_func *kernel; int8_t *data; /* which rows to process */ int start; int end; /* width of a column */ int res; }; static int do_work(void *arg) { const struct work_split *split = arg; size_t idx = 0; for (int x = split->start; x < split->end; ++x) for (int y = 0 ; y < split->res; ++y) { /* convert [0, res] to [-1, 1] */ float sx = -(x - split->res / 2) / ((float)split->res / 2); float sy = (y - split->res / 2) / ((float)split->res / 2); struct ejit_arg args[2] = { EJIT_ARG(sy, float), EJIT_ARG(sx, float) }; /* single byte at a time */ float r = ejit_run_func_f(split->kernel, 2, args); split->data[idx++] = (r < 0) ? 255 : 0; } return 0; } static void usage() { printf("usage: prospero [-j threads] [-r resolution] [-c]\n"); printf(" -j how many threads to run\n"); printf(" -r length of one side\n"); printf(" -c enable JIT\n"); } int main(int argc, char *argv[argc]) { int j = 1; int res = 1024; bool jit = false; int option = 0; while ((option = getopt(argc, argv, "hj:r:c")) != -1) { switch (option) { case 'h': usage(); exit(EXIT_SUCCESS); case 'j': j = atoi(optarg); break; case 'r': res = atoi(optarg); break; case 'c': jit = true; break; default: usage(); exit(EXIT_FAILURE); } } if (optind != argc) { usage(); exit(EXIT_FAILURE); } struct ejit_func *kernel = compile("prospero.vm", jit); assert(kernel); thrd_t *workers = malloc(sizeof(thrd_t) * j); struct work_split *splits = malloc(sizeof(struct work_split) * j); /* calculate how much work each thread should do */ int work_len = res / j; for (int i = 0; i < j; ++i) { splits[i].kernel = kernel; splits[i].res = res; splits[i].start = i * work_len; /* give last split the leftovers */ if (i == j - 1) splits[i].end = res; else splits[i].end = (i + 1) * work_len; size_t len = splits[i].end - splits[i].start; splits[i].data = malloc(res * len * sizeof(int8_t)); } /* start threads */ for (int i = 0; i < j; ++i) { int r = thrd_create(&workers[i], do_work, &splits[i]); assert(r == thrd_success); } /* collect threads */ for (int i = 0; i < j; ++i) { int res = 0; thrd_join(workers[i], &res); assert(res == 0); } ejit_destroy_func(kernel); FILE *f = fopen("out.ppm", "wb+"); assert(f); /* ppm header */ fprintf(f, "P5\n%d %d\n255\n", res, res); /* image data */ for (int i = 0; i < j; ++i) { size_t len = splits[i].end - splits[i].start; fwrite(splits[i].data, sizeof(int8_t), res * len, f); free(splits[i].data); } free(workers); free(splits); fclose(f); return 0; }