aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.c434
1 files changed, 434 insertions, 0 deletions
diff --git a/src/main.c b/src/main.c
new file mode 100644
index 0000000..6f18e9e
--- /dev/null
+++ b/src/main.c
@@ -0,0 +1,434 @@
+#include <math.h>
+#include <ctype.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <threads.h>
+
+#include <ejit/ejit.h>
+
+/* input parameters */
+#define VAR_X EJIT_FPR(0)
+#define VAR_Y EJIT_FPR(1)
+
+/* 'rest' of parameters, use this as a kind of base */
+#define VAR_0 EJIT_FPR(2)
+
+enum op {
+ OP_UNKNOWN,
+ OP_VAR_X,
+ OP_VAR_Y,
+ OP_CONST,
+ OP_ADD,
+ OP_SUB,
+ OP_MUL,
+ OP_MAX,
+ OP_MIN,
+ OP_NEG,
+ OP_SQUARE,
+ OP_SQRT
+};
+
+static float escape_sqrt(size_t argc, const struct ejit_arg args[argc])
+{
+ assert(argc == 1);
+ assert(args[0].type == EJIT_FLOAT);
+ return sqrtf(args[0].f);
+}
+
+static const char *skip_whitespace(const char *line)
+{
+ while (*line && isspace(*line))
+ line++;
+
+ return line;
+}
+
+static const char *next_whitespace(const char *line)
+{
+ while (*line && !isspace(*line))
+ line++;
+
+ return line;
+}
+
+static const char *parse_var(const char *line, size_t *var)
+{
+ assert(var);
+
+ line = skip_whitespace(line);
+ assert(line[0] == '_');
+
+ char *end = NULL;
+ *var = strtoull(line + 1, &end, 16) + VAR_0.f;
+ assert(end > line + 1);
+
+ return end;
+}
+
+static const char *parse_op(const char *line, enum op *op)
+{
+ assert(op);
+
+ const char *start = skip_whitespace(line);
+ const char *end = next_whitespace(start);
+
+ int len = end - start;
+
+ /* this is terrible but whatever */
+ if (len == 6) {
+ if (strncmp(start, "square", len) == 0)
+ *op = OP_SQUARE;
+ else
+ fprintf(stderr, "unknown op: %.*s\n", len, start);
+ }
+ else if (len == 5) {
+ if (strncmp(start, "var-x", len) == 0)
+ *op = OP_VAR_X;
+
+ else if (strncmp(start, "var-y", len) == 0)
+ *op = OP_VAR_Y;
+
+ else if (strncmp(start, "const", len) == 0)
+ *op = OP_CONST;
+ else
+ fprintf(stderr, "unknown op: %.*s\n", len, start);
+ }
+ else if (len == 4) {
+ if (strncmp(start, "sqrt", len) == 0)
+ *op = OP_SQRT;
+ else
+ fprintf(stderr, "unknown op: %.*s\n", len, start);
+ }
+ else if (len == 3) {
+ if (strncmp(start, "add", len) == 0)
+ *op = OP_ADD;
+
+ else if (strncmp(start, "sub", len) == 0)
+ *op = OP_SUB;
+
+ else if (strncmp(start, "mul", len) == 0)
+ *op = OP_MUL;
+
+ else if (strncmp(start, "max", len) == 0)
+ *op = OP_MAX;
+
+ else if (strncmp(start, "min", len) == 0)
+ *op = OP_MIN;
+
+ else if (strncmp(start, "neg", len) == 0)
+ *op = OP_NEG;
+ else
+ fprintf(stderr, "unknown op: %.*s\n", len, start);
+ }
+ else {
+ fprintf(stderr, "unknown op: %.*s\n", len, start);
+ }
+
+ return end;
+}
+
+static size_t compile_line(struct ejit_func *kernel, const char *line)
+{
+ /* not a particularly good parser but dusting off flex+bison felt like
+ * overkill */
+ line = skip_whitespace(line);
+
+ /* hopefully the last line is not a comment :) */
+ if (line[0] == '#')
+ return 0;
+
+ size_t out = 0;
+ line = parse_var(line, &out);
+
+ enum op op = OP_UNKNOWN;
+ line = parse_op(line, &op);
+ switch (op) {
+ case OP_VAR_X: {
+ ejit_movr_f(kernel, EJIT_FPR(out), VAR_X);
+ break;
+ }
+
+ case OP_VAR_Y: {
+ ejit_movr_f(kernel, EJIT_FPR(out), VAR_Y);
+ break;
+ }
+
+ case OP_CONST: {
+ char *end = NULL;
+ float c = strtof(line, &end);
+ ejit_movi_f(kernel, EJIT_FPR(out), c);
+ line = end;
+ break;
+ }
+
+ case OP_ADD: {
+ size_t arg0 = 0, arg1 = 0;
+ line = parse_var(line, &arg0);
+ line = parse_var(line, &arg1);
+ ejit_addr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+ break;
+ }
+
+ case OP_SUB: {
+ size_t arg0 = 0, arg1 = 0;
+ line = parse_var(line, &arg0);
+ line = parse_var(line, &arg1);
+ ejit_subr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+ break;
+ }
+
+ case OP_MUL: {
+ size_t arg0 = 0, arg1 = 0;
+ line = parse_var(line, &arg0);
+ line = parse_var(line, &arg1);
+ ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+ break;
+ }
+
+ case OP_MAX: {
+ size_t arg0 = 0, arg1 = 0;
+ line = parse_var(line, &arg0);
+ line = parse_var(line, &arg1);
+
+ /* no built-in fmax/fmin so use a plain old branch */
+ struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1));
+ /* arg0 < arg1 */
+ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1));
+ struct ejit_reloc j = ejit_jmp(kernel);
+
+ /* arg0 >= arg1 */
+ ejit_patch(kernel, r, ejit_label(kernel));
+ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0));
+
+ ejit_patch(kernel, j, ejit_label(kernel));
+ break;
+ }
+
+ case OP_MIN: {
+ size_t arg0 = 0, arg1 = 0;
+ line = parse_var(line, &arg0);
+ line = parse_var(line, &arg1);
+
+ struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1));
+ /* arg0 < arg1 */
+ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0));
+ struct ejit_reloc j = ejit_jmp(kernel);
+
+ /* arg0 >= arg1 */
+ ejit_patch(kernel, r, ejit_label(kernel));
+ ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1));
+
+ ejit_patch(kernel, j, ejit_label(kernel));
+ break;
+ }
+
+ case OP_NEG: {
+ size_t arg = 0;
+ line = parse_var(line, &arg);
+ ejit_negr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg));
+ break;
+ }
+
+ case OP_SQUARE: {
+ size_t arg = 0;
+ line = parse_var(line, &arg);
+ ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg), EJIT_FPR(arg));
+ break;
+ }
+
+ case OP_SQRT: {
+ size_t arg = 0;
+ line = parse_var(line, &arg);
+ /* unfortunately no built-in operation for sqrt so we have to do
+ * an expensive call */
+ struct ejit_operand op = EJIT_OPERAND_FPR(arg, EJIT_FLOAT);
+ ejit_escapei_f(kernel, escape_sqrt, 1, &op);
+ ejit_retval_f(kernel, EJIT_FPR(out));
+ break;
+ }
+
+ case OP_UNKNOWN:
+ abort();
+ break;
+ }
+
+ assert(line[0] == '\n');
+ return out;
+}
+
+static struct ejit_func *compile(const char *file, bool do_jit)
+{
+ FILE *f = fopen(file, "rb");
+ assert(f);
+
+ struct ejit_operand operands[2] = {
+ EJIT_OPERAND_FPR(VAR_X.f, EJIT_FLOAT),
+ EJIT_OPERAND_FPR(VAR_Y.f, EJIT_FLOAT)
+ };
+ struct ejit_func *kernel = ejit_create_func(EJIT_FLOAT, 2, operands);
+ assert(kernel);
+
+ size_t size = 0;
+ char *line = NULL;
+ size_t out = 0;
+ while (getline(&line, &size, f) != -1)
+ out = compile_line(kernel, line);
+
+ ejit_retr_f(kernel, EJIT_FPR(out));
+
+ /* normally you'd just use ejit_compile_func(kernel), but since I want
+ * to explicitly compare JIT vs bytecode performance, I use this uglier,
+ * 'internal' compile function */
+ ejit_select_compile_func(kernel, 0, out + 1,
+ /* we don't use any 64bit values */
+ false,
+ /* compare jit vs bytecode */
+ do_jit,
+ /* if we do jit, we want to mark pages read-only */
+ true);
+ free(line);
+ fclose(f);
+ return kernel;
+}
+
+struct work_split {
+ struct ejit_func *kernel;
+
+ int8_t *data;
+
+ /* which rows to process */
+ int start;
+ int end;
+
+ /* width of a column */
+ int res;
+};
+
+static int do_work(void *arg)
+{
+ const struct work_split *split = arg;
+
+ size_t idx = 0;
+ for (int x = split->start; x < split->end; ++x)
+ for (int y = 0 ; y < split->res; ++y)
+ {
+ /* convert [0, res] to [-1, 1] */
+ float sx = -(x - split->res / 2) / ((float)split->res / 2);
+ float sy = (y - split->res / 2) / ((float)split->res / 2);
+ struct ejit_arg args[2] = {
+ EJIT_ARG(sy, float),
+ EJIT_ARG(sx, float)
+ };
+
+ /* single byte at a time */
+ float r = ejit_run_func_f(split->kernel, 2, args);
+ split->data[idx++] = (r < 0) ? 255 : 0;
+ }
+
+ return 0;
+}
+
+static void usage()
+{
+ printf("usage: prospero [-j threads] [-r resolution] [-c]\n");
+ printf(" -j how many threads to run\n");
+ printf(" -r length of one side\n");
+ printf(" -c enable JIT\n");
+}
+
+int main(int argc, char *argv[argc])
+{
+
+ int j = 1;
+ int res = 1024;
+ bool jit = false;
+
+ int option = 0;
+ while ((option = getopt(argc, argv, "hj:r:c")) != -1) {
+ switch (option) {
+ case 'h':
+ usage();
+ exit(EXIT_SUCCESS);
+
+ case 'j':
+ j = atoi(optarg);
+ break;
+
+ case 'r':
+ res = atoi(optarg);
+ break;
+
+ case 'c':
+ jit = true;
+ break;
+
+ default:
+ usage();
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ if (optind != argc) {
+ usage();
+ exit(EXIT_FAILURE);
+ }
+
+ struct ejit_func *kernel = compile("prospero.vm", jit);
+ assert(kernel);
+
+ thrd_t *workers = malloc(sizeof(thrd_t) * j);
+ struct work_split *splits = malloc(sizeof(struct work_split) * j);
+
+ /* calculate how much work each thread should do */
+ int work_len = res / j;
+ for (int i = 0; i < j; ++i) {
+ splits[i].kernel = kernel;
+ splits[i].res = res;
+ splits[i].start = i * work_len;
+
+ /* give last split the leftovers */
+ if (i == j - 1)
+ splits[i].end = res;
+ else
+ splits[i].end = (i + 1) * work_len;
+
+ size_t len = splits[i].end - splits[i].start;
+ splits[i].data = malloc(res * len * sizeof(int8_t));
+ }
+
+ /* start threads */
+ for (int i = 0; i < j; ++i) {
+ int r = thrd_create(&workers[i], do_work, &splits[i]);
+ assert(r == thrd_success);
+ }
+
+
+ /* collect threads */
+ for (int i = 0; i < j; ++i) {
+ int res = 0;
+ thrd_join(workers[i], &res);
+ assert(res == 0);
+ }
+
+ ejit_destroy_func(kernel);
+
+ FILE *f = fopen("out.ppm", "wb+");
+ assert(f);
+
+ /* ppm header */
+ fprintf(f, "P5\n%d %d\n255\n", res, res);
+
+ /* image data */
+ for (int i = 0; i < j; ++i) {
+ size_t len = splits[i].end - splits[i].start;
+ fwrite(splits[i].data, sizeof(int8_t), res * len, f);
+ free(splits[i].data);
+ }
+
+ free(workers);
+ free(splits);
+ fclose(f);
+ return 0;
+}