From 6d88dce5ebbc6fdb91054e5ffc3f8843d20b2b97 Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Wed, 26 Mar 2025 18:10:41 +0200
Subject: initial commit

---
 src/main.c | 434 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 434 insertions(+)
 create mode 100644 src/main.c

(limited to 'src/main.c')

diff --git a/src/main.c b/src/main.c
new file mode 100644
index 0000000..6f18e9e
--- /dev/null
+++ b/src/main.c
@@ -0,0 +1,434 @@
+#include <math.h>
+#include <ctype.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <threads.h>
+
+#include <ejit/ejit.h>
+
+/* input parameters */
+#define VAR_X EJIT_FPR(0)
+#define VAR_Y EJIT_FPR(1)
+
+/* 'rest' of parameters, use this as a kind of base */
+#define VAR_0 EJIT_FPR(2)
+
+enum op {
+	OP_UNKNOWN,
+	OP_VAR_X,
+	OP_VAR_Y,
+	OP_CONST,
+	OP_ADD,
+	OP_SUB,
+	OP_MUL,
+	OP_MAX,
+	OP_MIN,
+	OP_NEG,
+	OP_SQUARE,
+	OP_SQRT
+};
+
+static float escape_sqrt(size_t argc, const struct ejit_arg args[argc])
+{
+	assert(argc == 1);
+	assert(args[0].type == EJIT_FLOAT);
+	return sqrtf(args[0].f);
+}
+
+static const char *skip_whitespace(const char *line)
+{
+	while (*line && isspace(*line))
+		line++;
+
+	return line;
+}
+
+static const char *next_whitespace(const char *line)
+{
+	while (*line && !isspace(*line))
+		line++;
+
+	return line;
+}
+
+static const char *parse_var(const char *line, size_t *var)
+{
+	assert(var);
+
+	line = skip_whitespace(line);
+	assert(line[0] == '_');
+
+	char *end = NULL;
+	*var = strtoull(line + 1, &end, 16) + VAR_0.f;
+	assert(end > line + 1);
+
+	return end;
+}
+
+static const char *parse_op(const char *line, enum op *op)
+{
+	assert(op);
+
+	const char *start = skip_whitespace(line);
+	const char *end = next_whitespace(start);
+
+	int len = end - start;
+
+	/* this is terrible but whatever */
+	if (len == 6) {
+		if (strncmp(start, "square", len) == 0)
+			*op = OP_SQUARE;
+		else
+			fprintf(stderr, "unknown op: %.*s\n", len, start);
+	}
+	else if (len == 5) {
+		if (strncmp(start, "var-x", len) == 0)
+			*op = OP_VAR_X;
+
+		else if (strncmp(start, "var-y", len) == 0)
+			*op = OP_VAR_Y;
+
+		else if (strncmp(start, "const", len) == 0)
+			*op = OP_CONST;
+		else
+			fprintf(stderr, "unknown op: %.*s\n", len, start);
+	}
+	else if (len == 4) {
+		if (strncmp(start, "sqrt", len) == 0)
+			*op = OP_SQRT;
+		else
+			fprintf(stderr, "unknown op: %.*s\n", len, start);
+	}
+	else if (len == 3) {
+		if (strncmp(start, "add", len) == 0)
+			*op = OP_ADD;
+
+		else if (strncmp(start, "sub", len) == 0)
+			*op = OP_SUB;
+
+		else if (strncmp(start, "mul", len) == 0)
+			*op = OP_MUL;
+
+		else if (strncmp(start, "max", len) == 0)
+			*op = OP_MAX;
+
+		else if (strncmp(start, "min", len) == 0)
+			*op = OP_MIN;
+
+		else if (strncmp(start, "neg", len) == 0)
+			*op = OP_NEG;
+		else
+			fprintf(stderr, "unknown op: %.*s\n", len, start);
+	}
+	else {
+		fprintf(stderr, "unknown op: %.*s\n", len, start);
+	}
+
+	return end;
+}
+
+static size_t compile_line(struct ejit_func *kernel, const char *line)
+{
+	/* not a particularly good parser but dusting off flex+bison felt like
+	 * overkill */
+	line = skip_whitespace(line);
+
+	/* hopefully the last line is not a comment :) */
+	if (line[0] == '#')
+		return 0;
+
+	size_t out = 0;
+	line = parse_var(line, &out);
+
+	enum op op = OP_UNKNOWN;
+	line = parse_op(line, &op);
+	switch (op) {
+	case OP_VAR_X: {
+		ejit_movr_f(kernel, EJIT_FPR(out), VAR_X);
+		break;
+	}
+
+	case OP_VAR_Y: {
+		ejit_movr_f(kernel, EJIT_FPR(out), VAR_Y);
+		break;
+	}
+
+	case OP_CONST: {
+		char *end = NULL;
+		float c = strtof(line, &end);
+		ejit_movi_f(kernel, EJIT_FPR(out), c);
+		line = end;
+		break;
+	}
+
+	case OP_ADD: {
+		size_t arg0 = 0, arg1 = 0;
+		line = parse_var(line, &arg0);
+		line = parse_var(line, &arg1);
+		ejit_addr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+		break;
+	}
+
+	case OP_SUB: {
+		size_t arg0 = 0, arg1 = 0;
+		line = parse_var(line, &arg0);
+		line = parse_var(line, &arg1);
+		ejit_subr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+		break;
+	}
+
+	case OP_MUL: {
+		size_t arg0 = 0, arg1 = 0;
+		line = parse_var(line, &arg0);
+		line = parse_var(line, &arg1);
+		ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
+		break;
+	}
+
+	case OP_MAX: {
+		size_t arg0 = 0, arg1 = 0;
+		line = parse_var(line, &arg0);
+		line = parse_var(line, &arg1);
+
+		/* no built-in fmax/fmin so use a plain old branch */
+		struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1));
+		/* arg0 < arg1 */
+		ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1));
+		struct ejit_reloc j = ejit_jmp(kernel);
+
+		/* arg0 >= arg1 */
+		ejit_patch(kernel, r, ejit_label(kernel));
+		ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0));
+
+		ejit_patch(kernel, j, ejit_label(kernel));
+		break;
+	}
+
+	case OP_MIN: {
+		size_t arg0 = 0, arg1 = 0;
+		line = parse_var(line, &arg0);
+		line = parse_var(line, &arg1);
+
+		struct ejit_reloc r = ejit_bger_f(kernel, EJIT_FPR(arg0), EJIT_FPR(arg1));
+		/* arg0 < arg1 */
+		ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0));
+		struct ejit_reloc j = ejit_jmp(kernel);
+
+		/* arg0 >= arg1 */
+		ejit_patch(kernel, r, ejit_label(kernel));
+		ejit_movr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg1));
+
+		ejit_patch(kernel, j, ejit_label(kernel));
+		break;
+	}
+
+	case OP_NEG: {
+		size_t arg = 0;
+		line = parse_var(line, &arg);
+		ejit_negr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg));
+		break;
+	}
+
+	case OP_SQUARE: {
+		size_t arg = 0;
+		line = parse_var(line, &arg);
+		ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg), EJIT_FPR(arg));
+		break;
+	}
+
+	case OP_SQRT: {
+		size_t arg = 0;
+		line = parse_var(line, &arg);
+		/* unfortunately no built-in operation for sqrt so we have to do
+		 * an expensive call */
+		struct ejit_operand op = EJIT_OPERAND_FPR(arg, EJIT_FLOAT);
+		ejit_escapei_f(kernel, escape_sqrt, 1, &op);
+		ejit_retval_f(kernel, EJIT_FPR(out));
+		break;
+	}
+
+	case OP_UNKNOWN:
+		abort();
+		break;
+	}
+
+	assert(line[0] == '\n');
+	return out;
+}
+
+static struct ejit_func *compile(const char *file, bool do_jit)
+{
+	FILE *f = fopen(file, "rb");
+	assert(f);
+
+	struct ejit_operand operands[2] = {
+		EJIT_OPERAND_FPR(VAR_X.f, EJIT_FLOAT),
+		EJIT_OPERAND_FPR(VAR_Y.f, EJIT_FLOAT)
+	};
+	struct ejit_func *kernel = ejit_create_func(EJIT_FLOAT, 2, operands);
+	assert(kernel);
+
+	size_t size = 0;
+	char *line = NULL;
+	size_t out = 0;
+	while (getline(&line, &size, f) != -1)
+		out = compile_line(kernel, line);
+
+	ejit_retr_f(kernel, EJIT_FPR(out));
+
+	/* normally you'd just use ejit_compile_func(kernel), but since I want
+	 * to explicitly compare JIT vs bytecode performance, I use this uglier,
+	 * 'internal' compile function */
+	ejit_select_compile_func(kernel, 0, out + 1,
+			/* we don't use any 64bit values */
+			false,
+			/* compare jit vs bytecode */
+			do_jit,
+			/* if we do jit, we want to mark pages read-only */
+			true);
+	free(line);
+	fclose(f);
+	return kernel;
+}
+
+struct work_split {
+	struct ejit_func *kernel;
+
+	int8_t *data;
+
+	/* which rows to process */
+	int start;
+	int end;
+
+	/* width of a column */
+	int res;
+};
+
+static int do_work(void *arg)
+{
+	const struct work_split *split = arg;
+
+	size_t idx = 0;
+	for (int x = split->start; x < split->end; ++x)
+	for (int y = 0           ; y < split->res; ++y)
+	{
+		/* convert [0, res] to [-1, 1] */
+		float sx = -(x - split->res / 2) / ((float)split->res / 2);
+		float sy =  (y - split->res / 2) / ((float)split->res / 2);
+		struct ejit_arg args[2] = {
+			EJIT_ARG(sy, float),
+			EJIT_ARG(sx, float)
+		};
+
+		/* single byte at a time */
+		float r = ejit_run_func_f(split->kernel, 2, args);
+		split->data[idx++] = (r < 0) ? 255 : 0;
+	}
+
+	return 0;
+}
+
+static void usage()
+{
+	printf("usage: prospero [-j threads] [-r resolution] [-c]\n");
+	printf(" -j how many threads to run\n");
+	printf(" -r length of one side\n");
+	printf(" -c enable JIT\n");
+}
+
+int main(int argc, char *argv[argc])
+{
+
+	int j = 1;
+	int res = 1024;
+	bool jit = false;
+
+	int option = 0;
+	while ((option = getopt(argc, argv, "hj:r:c")) != -1) {
+		switch (option) {
+		case 'h':
+			usage();
+			exit(EXIT_SUCCESS);
+
+		case 'j':
+			j = atoi(optarg);
+			break;
+
+		case 'r':
+			res = atoi(optarg);
+			break;
+
+		case 'c':
+			jit = true;
+			break;
+
+		default:
+			usage();
+			exit(EXIT_FAILURE);
+		}
+	}
+
+	if (optind != argc) {
+		usage();
+		exit(EXIT_FAILURE);
+	}
+
+	struct ejit_func *kernel = compile("prospero.vm", jit);
+	assert(kernel);
+
+	thrd_t *workers = malloc(sizeof(thrd_t) * j);
+	struct work_split *splits = malloc(sizeof(struct work_split) * j);
+
+	/* calculate how much work each thread should do */
+	int work_len = res / j;
+	for (int i = 0; i < j; ++i) {
+		splits[i].kernel = kernel;
+		splits[i].res = res;
+		splits[i].start = i * work_len;
+
+		/* give last split the leftovers */
+		if (i == j - 1)
+			splits[i].end = res;
+		else
+			splits[i].end = (i + 1) * work_len;
+
+		size_t len = splits[i].end - splits[i].start;
+		splits[i].data = malloc(res * len * sizeof(int8_t));
+	}
+
+	/* start threads */
+	for (int i = 0; i < j; ++i) {
+		int r = thrd_create(&workers[i], do_work, &splits[i]);
+		assert(r == thrd_success);
+	}
+
+
+	/* collect threads */
+	for (int i = 0; i < j; ++i) {
+		int res = 0;
+		thrd_join(workers[i], &res);
+		assert(res == 0);
+	}
+
+	ejit_destroy_func(kernel);
+
+	FILE *f = fopen("out.ppm", "wb+");
+	assert(f);
+
+	/* ppm header */
+	fprintf(f, "P5\n%d %d\n255\n", res, res);
+
+	/* image data */
+	for (int i = 0; i < j; ++i) {
+		size_t len = splits[i].end - splits[i].start;
+		fwrite(splits[i].data, sizeof(int8_t), res * len, f);
+		free(splits[i].data);
+	}
+
+	free(workers);
+	free(splits);
+	fclose(f);
+	return 0;
+}
-- 
cgit v1.2.3