#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <threads.h>
#include <unistd.h>

#include <ejit/ejit.h>

/* input parameters: the two float arguments of the kernel; compile() declares
 * them as the kernel's operands in this order */
#define VAR_X EJIT_FPR(0)
#define VAR_Y EJIT_FPR(1)

/* base register for the variables of the input program: parse_var() adds
 * VAR_0.f to every parsed `_<hex>` index, presumably so that program variables
 * are placed after the two input parameters */
#define VAR_0 EJIT_FPR(2)

enum op {
	OP_UNKNOWN,
	OP_VAR_X,
	OP_VAR_Y,
	OP_CONST,
	OP_ADD,
	OP_SUB,
	OP_MUL,
	OP_MAX,
	OP_MIN,
	OP_NEG,
	OP_SQUARE,
	OP_SQRT
};

/*
 * Advance past leading whitespace.
 *
 * Returns a pointer to the first character of `line` for which isspace() is
 * false; that is the terminating NUL when the rest of the string is blank.
 */
static const char *skip_whitespace(const char *line)
{
	/* <ctype.h> functions require a value representable as unsigned char
	 * (or EOF); a plain char may be negative for bytes >= 0x80 */
	while (*line && isspace((unsigned char)*line))
		line++;

	return line;
}

/*
 * Advance to the end of the current word.
 *
 * Returns a pointer to the first character of `line` for which isspace() is
 * true, or to the terminating NUL when the string contains no whitespace.
 */
static const char *next_whitespace(const char *line)
{
	/* <ctype.h> functions require a value representable as unsigned char
	 * (or EOF); a plain char may be negative for bytes >= 0x80 */
	while (*line && !isspace((unsigned char)*line))
		line++;

	return line;
}

/*
 * Parse a variable reference of the form `_<hex digits>`.
 *
 * Leading whitespace is skipped. The hexadecimal number following the
 * underscore is offset by VAR_0.f and stored in `*var`.
 *
 * Returns a pointer to the first character after the number. A missing
 * underscore or a missing number trips an assertion.
 */
static const char *parse_var(const char *line, size_t *var)
{
	assert(var != NULL);

	const char *name = skip_whitespace(line);
	assert(*name == '_');

	/* the index starts right after the underscore */
	const char *digits = name + 1;
	char *stop = NULL;
	unsigned long long index = strtoull(digits, &stop, 16);
	*var = index + VAR_0.f;

	/* strtoull() leaves `stop` at `digits` when nothing was converted */
	assert(stop > digits);
	return stop;
}

static const char *parse_op(const char *line, enum op *op)
{
	assert(op);

	const char *start = skip_whitespace(line);
	const char *end = next_whitespace(start);

	int len = end - start;

	/* this is terrible but whatever */
	if (len == 6) {
		if (strncmp(start, "square", len) == 0)
			*op = OP_SQUARE;
		else
			fprintf(stderr, "unknown op: %.*s\n", len, start);
	}
	else if (len == 5) {
		if (strncmp(start, "var-x", len) == 0)
			*op = OP_VAR_X;

		else if (strncmp(start, "var-y", len) == 0)
			*op = OP_VAR_Y;

		else if (strncmp(start, "const", len) == 0)
			*op = OP_CONST;
		else
			fprintf(stderr, "unknown op: %.*s\n", len, start);
	}
	else if (len == 4) {
		if (strncmp(start, "sqrt", len) == 0)
			*op = OP_SQRT;
		else
			fprintf(stderr, "unknown op: %.*s\n", len, start);
	}
	else if (len == 3) {
		if (strncmp(start, "add", len) == 0)
			*op = OP_ADD;

		else if (strncmp(start, "sub", len) == 0)
			*op = OP_SUB;

		else if (strncmp(start, "mul", len) == 0)
			*op = OP_MUL;

		else if (strncmp(start, "max", len) == 0)
			*op = OP_MAX;

		else if (strncmp(start, "min", len) == 0)
			*op = OP_MIN;

		else if (strncmp(start, "neg", len) == 0)
			*op = OP_NEG;
		else
			fprintf(stderr, "unknown op: %.*s\n", len, start);
	}
	else {
		fprintf(stderr, "unknown op: %.*s\n", len, start);
	}

	return end;
}

/*
 * Translate one line of the input program into ejit instructions that are
 * appended to `kernel`.
 *
 * A line is either a comment (first non-blank character is '#') or has the
 * form
 *
 *	_<hex> <op> [operands...]
 *
 * where the operands are variable references, or a float literal for `const`.
 *
 * Returns the FPR index that receives the result of the instruction, or 0 for
 * a comment line, in which case nothing is emitted. Callers should treat 0 as
 * "no instruction on this line". Malformed input trips an assertion, unknown
 * operations abort.
 */
static size_t compile_line(struct ejit_func *kernel, const char *line)
{
	/* not a particularly good parser but dusting off flex+bison felt like
	 * overkill */
	line = skip_whitespace(line);

	if (line[0] == '#')
		return 0;

	size_t out = 0;
	line = parse_var(line, &out);

	enum op op = OP_UNKNOWN;
	line = parse_op(line, &op);
	switch (op) {
	case OP_VAR_X: {
		ejit_movr_f(kernel, EJIT_FPR(out), VAR_X);
		break;
	}

	case OP_VAR_Y: {
		ejit_movr_f(kernel, EJIT_FPR(out), VAR_Y);
		break;
	}

	case OP_CONST: {
		char *end = NULL;
		float c = strtof(line, &end);
		/* strtof() leaves `end` at `line` when nothing was converted;
		 * without this check a missing literal silently becomes 0 */
		assert(end != line);
		ejit_movi_f(kernel, EJIT_FPR(out), c);
		line = end;
		break;
	}

	case OP_ADD: {
		size_t arg0 = 0, arg1 = 0;
		line = parse_var(line, &arg0);
		line = parse_var(line, &arg1);
		ejit_addr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
		break;
	}

	case OP_SUB: {
		size_t arg0 = 0, arg1 = 0;
		line = parse_var(line, &arg0);
		line = parse_var(line, &arg1);
		ejit_subr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
		break;
	}

	case OP_MUL: {
		size_t arg0 = 0, arg1 = 0;
		line = parse_var(line, &arg0);
		line = parse_var(line, &arg1);
		ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
		break;
	}

	case OP_MAX: {
		size_t arg0 = 0, arg1 = 0;
		line = parse_var(line, &arg0);
		line = parse_var(line, &arg1);
		ejit_maxr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
		break;
	}

	case OP_MIN: {
		size_t arg0 = 0, arg1 = 0;
		line = parse_var(line, &arg0);
		line = parse_var(line, &arg1);
		ejit_minr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg0), EJIT_FPR(arg1));
		break;
	}

	case OP_NEG: {
		size_t arg = 0;
		line = parse_var(line, &arg);
		ejit_negr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg));
		break;
	}

	case OP_SQUARE: {
		/* no dedicated instruction is used, square is arg * arg */
		size_t arg = 0;
		line = parse_var(line, &arg);
		ejit_mulr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg), EJIT_FPR(arg));
		break;
	}

	case OP_SQRT: {
		size_t arg = 0;
		line = parse_var(line, &arg);
		ejit_sqrtr_f(kernel, EJIT_FPR(out), EJIT_FPR(arg));
		break;
	}

	case OP_UNKNOWN:
	default:
		/* parse_op() has already reported the offending word */
		abort();
	}

	/* nothing but whitespace may follow the instruction; this accepts
	 * "\n", "\r\n", trailing blanks and a final line without a newline */
	line = skip_whitespace(line);
	assert(line[0] == '\0');
	return out;
}

/*
 * Read the program in `file` line by line and build an ejit function taking
 * two floats (VAR_X, VAR_Y) and returning a float.
 *
 * The value returned by the function is the result of the last instruction in
 * the file. `do_jit` is forwarded to ejit_select_compile_func() to choose
 * between JIT and bytecode.
 *
 * Returns the compiled function, or NULL (after printing a diagnostic) if
 * `file` cannot be opened. The caller releases the function with
 * ejit_destroy_func().
 */
static struct ejit_func *compile(const char *file, bool do_jit)
{
	FILE *f = fopen(file, "rb");
	if (!f) {
		perror(file);
		return NULL;
	}

	struct ejit_operand operands[2] = {
		EJIT_OPERAND_FPR(VAR_X.f, EJIT_FLOAT),
		EJIT_OPERAND_FPR(VAR_Y.f, EJIT_FLOAT)
	};
	struct ejit_func *kernel = ejit_create_func(EJIT_FLOAT, 2, operands);
	assert(kernel);

	size_t size = 0;
	char *line = NULL;
	size_t out = 0;
	while (getline(&line, &size, f) != -1) {
		size_t line_out = compile_line(kernel, line);

		/* compile_line() returns 0 for comment lines; keep the result
		 * of the last real instruction instead of clobbering it.
		 * Real results are offset by VAR_0.f, which is assumed to be
		 * nonzero (VAR_0 is built from EJIT_FPR(2)) */
		if (line_out != 0)
			out = line_out;
	}

	ejit_retr_f(kernel, EJIT_FPR(out));

	/* normally you'd just use ejit_compile_func(kernel), but since I want
	 * to explicitly compare JIT vs bytecode performance, I use this uglier,
	 * 'internal' compile function */
	ejit_select_compile_func(kernel, 0, out + 1,
			/* we don't use any 64bit values */
			false,
			/* compare jit vs bytecode */
			do_jit,
			/* if we do jit, we want to mark pages read-only */
			true);
	free(line);
	fclose(f);
	return kernel;
}

/* Description of the part of the image that one worker thread renders. */
struct work_split {
	/* compiled kernel that do_work() evaluates once per pixel */
	struct ejit_func *kernel;

	/* output buffer with one byte per pixel; do_work() fills it
	 * sequentially starting at index 0 */
	int8_t *data;

	/* which rows to process: start is inclusive, end is exclusive */
	int start;
	int end;

	/* number of pixels in a row; main() sets it to the side length of the
	 * image */
	int res;
};

/*
 * Thread entry point: evaluate the kernel for every pixel of the rows
 * [start, end) described by the `struct work_split` passed in `arg`.
 *
 * Pixels are written to the split's data buffer in row-major order, 255 where
 * the kernel result is negative and 0 elsewhere. Always returns 0.
 */
static int do_work(void *arg)
{
	const struct work_split *job = arg;

	/* centre of the image in pixels and the factor that maps pixel
	 * offsets from the centre to [-1, 1] */
	const int half = job->res / 2;
	const float scale = (float)job->res / 2;

	int8_t *pixel = job->data;
	for (int row = job->start; row < job->end; ++row) {
		for (int col = 0; col < job->res; ++col) {
			/* the row index is flipped and fed to the kernel's
			 * second input, the column index to its first */
			float from_row = -(row - half) / scale;
			float from_col = (col - half) / scale;
			struct ejit_arg args[2] = {
				EJIT_ARG(from_col, float),
				EJIT_ARG(from_row, float)
			};

			/* single byte at a time */
			float r = ejit_run_func_f(job->kernel, 2, args);
			*pixel++ = (r < 0) ? 255 : 0;
		}
	}

	return 0;
}

/* Print the command line synopsis and a description of each option to stdout. */
static void usage()
{
	static const char *const help[] = {
		"usage: prospero [-j threads] [-r resolution] [-c]\n",
		" -j how many threads to run\n",
		" -r length of one side\n",
		" -c enable JIT\n",
	};

	for (size_t i = 0; i < sizeof(help) / sizeof(help[0]); ++i)
		fputs(help[i], stdout);
}

/*
 * Parse `text` as a decimal integer in the range [1, INT_MAX].
 *
 * On anything else (not a number, trailing garbage, zero, negative, out of
 * range) a message naming `what` and the usage text are printed and the
 * program exits with a failure status.
 */
static int parse_positive_int(const char *text, const char *what)
{
	char *end = NULL;
	errno = 0;
	long value = strtol(text, &end, 10);
	if (errno != 0 || end == text || *end != '\0'
			|| value < 1 || value > INT_MAX) {
		fprintf(stderr, "invalid %s: %s\n", what, text);
		usage();
		exit(EXIT_FAILURE);
	}

	return (int)value;
}

/*
 * Compile prospero.vm, evaluate it for every pixel of a res x res image on
 * `j` worker threads and write the result to out.ppm.
 *
 * Options: -j <threads>, -r <resolution>, -c (enable JIT), -h (help).
 * Returns 0 on success and a failure status if the image could not be written
 * completely; setup failures terminate the program.
 */
int main(int argc, char *argv[argc])
{

	int j = 1;
	int res = 1024;
	bool jit = false;

	int option = 0;
	while ((option = getopt(argc, argv, "hj:r:c")) != -1) {
		switch (option) {
		case 'h':
			usage();
			exit(EXIT_SUCCESS);

		case 'j':
			/* must be >= 1, it is used as a divisor and as an
			 * allocation count below */
			j = parse_positive_int(optarg, "thread count");
			break;

		case 'r':
			res = parse_positive_int(optarg, "resolution");
			break;

		case 'c':
			jit = true;
			break;

		default:
			usage();
			exit(EXIT_FAILURE);
		}
	}

	if (optind != argc) {
		usage();
		exit(EXIT_FAILURE);
	}

	struct ejit_func *kernel = compile("prospero.vm", jit);
	if (!kernel) {
		fprintf(stderr, "failed to compile prospero.vm\n");
		exit(EXIT_FAILURE);
	}

	thrd_t *workers = malloc(sizeof(*workers) * j);
	struct work_split *splits = malloc(sizeof(*splits) * j);
	if (!workers || !splits) {
		perror("malloc");
		exit(EXIT_FAILURE);
	}

	/* calculate how much work each thread should do */
	int work_len = res / j;
	for (int i = 0; i < j; ++i) {
		splits[i].kernel = kernel;
		splits[i].res = res;
		splits[i].start = i * work_len;

		/* give last split the leftovers */
		if (i == j - 1)
			splits[i].end = res;
		else
			splits[i].end = (i + 1) * work_len;

		size_t rows = splits[i].end - splits[i].start;
		size_t bytes = (size_t)res * rows * sizeof(int8_t);
		splits[i].data = malloc(bytes);

		/* with more threads than rows some splits are empty, and
		 * malloc(0) is allowed to return NULL */
		if (!splits[i].data && bytes != 0) {
			perror("malloc");
			exit(EXIT_FAILURE);
		}
	}

	/* start threads */
	for (int i = 0; i < j; ++i) {
		if (thrd_create(&workers[i], do_work, &splits[i]) != thrd_success) {
			fprintf(stderr, "failed to start worker thread %d\n", i);
			exit(EXIT_FAILURE);
		}
	}


	/* collect threads */
	for (int i = 0; i < j; ++i) {
		int rc = 0;
		thrd_join(workers[i], &rc);
		assert(rc == 0);
	}

	ejit_destroy_func(kernel);

	FILE *f = fopen("out.ppm", "wb+");
	if (!f) {
		perror("out.ppm");
		exit(EXIT_FAILURE);
	}

	int status = EXIT_SUCCESS;

	/* ppm header */
	if (fprintf(f, "P5\n%d %d\n255\n", res, res) < 0)
		status = EXIT_FAILURE;

	/* image data, the splits are consecutive groups of rows */
	for (int i = 0; i < j; ++i) {
		size_t rows = splits[i].end - splits[i].start;
		size_t count = (size_t)res * rows;
		if (fwrite(splits[i].data, sizeof(int8_t), count, f) != count)
			status = EXIT_FAILURE;

		free(splits[i].data);
	}

	free(workers);
	free(splits);

	/* fclose() flushes, so a failure here also means a short file */
	if (fclose(f) != 0)
		status = EXIT_FAILURE;

	if (status != EXIT_SUCCESS)
		fprintf(stderr, "failed to write out.ppm\n");

	return status;
}