#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>

#include <posthaste/lower.h>
#include <posthaste/scope.h>
#include <posthaste/utils.h>
#include <posthaste/date.h>

/* The bytecode is similar in construction to Lua's, in that there are two
 * arrays, one for global values and one for local values. The local value array
 * is referred to as the stack. In theory each function/procedure could get its
 * own stack, but in this case there's just a single global stack, and each call
 * moves a stack pointer up to signify entering a new function. Procedure/function
 * formal parameters are the first things on the stack.
 *
 * All instructions are of the three-value kind, i.e. an instruction has an
 * opcode, an output location and two input locations, not all of which have to
 * be used. There's a special null_loc() to signify an unused location.
 * Locations are effectively just indexes into either the global or local
 * array. I've also added a separate CONST instruction that pushes a constant
 * value to a specified output location.
 *
 * Individual instructions are very wide, which decreases cache performance and
 * makes this particular implementation fairly slow to interpret, but the extra
 * width makes it a little easier to work with, since nothing has to be
 * compressed/decompressed. In particular, I tried to make JIT compilation
 * as easy as possible, at the expense of interpretation speed.
 *
 * In this file the lowering targets ejit directly: local stack slot N is the
 * register EJIT_GPR(N), and globals are reached through the global array
 * pointer that the file-level main function receives in register 0.
 */

/* Globals are kind of ugly and could/should be made into parameters (or
 * potentially even members of some all-encompassing state), but this works for
 * this simple implementation. Implementing the change would mostly just require
 * changing function signatures, which is boring and I don't want to do it right
 * now. */
/* Every lowered function (struct fn). Index 0 is always the implicit "main"
 * function built from the file-level statements (see lower_ast()); procedures
 * and functions are appended after it in definition order by add_proc() and
 * add_func(). */
static struct vec fns = {0};

/* Next free slot in the global value array. Slot 0 is never handed out:
 * put()/get()/get_id_loc() assert that a location's .g is never 0, and
 * .g == -1 marks a local (non-global) location. */
static size_t globals = 1;

/* get register/stack slot for AST node: a node's value lives in the JIT
 * general-purpose register numbered by its stack slot (.s) */
#define regno(x) ((x)->l.s)
#define reg(x) EJIT_GPR(regno(x))

/* Number of global value slots handed out so far, including the reserved
 * slot 0. NOTE(review): presumably used by the caller to size the global array
 * passed to the main function — confirm against the caller.
 *
 * Declared with (void) so it is a real prototype before C23. */
size_t num_globals(void)
{
	return globals;
}

static void lower(struct fn *f, struct ast *n);
static void lower_list(struct fn *f, struct ast *l)
{
	foreach_node(n, l) {
		lower(f, n);
	}
}

static struct loc build_local_loc(size_t idx)
{
	return (struct loc){.g = -1, .s = idx};
}

static struct loc build_global_loc(size_t idx)
{
	return (struct loc){.g = idx, .s = -1};
}

static struct loc null_loc()
{
	return (struct loc){.g = -1, .s = -1};
}

/* Map a source-language type to the ejit type used to pass/return it.
 * Integers, booleans and dates are all carried as 64-bit integers, strings
 * as pointers. Any other kind aborts. */
static enum ejit_type ejit_type_from(enum type_kind t)
{
	switch (t) {
	case TYPE_INT:
	case TYPE_BOOL:
	case TYPE_DATE:
		return EJIT_INT64;

	case TYPE_STRING:
		return EJIT_POINTER;

	case TYPE_VOID:
		return EJIT_VOID;

	default:
		break;
	}

	abort();
}

/* If l is a global, store the value held in its stack register (l.s) into
 * slot l.g of the global array, addressed through EJIT_GPR(0). Locals already
 * live in their register, so nothing is emitted for them.
 * NOTE(review): EJIT_GPR(0) holds the global-array pointer only in the
 * file-level main function (lower_ast() reserves it there); in procedures and
 * functions register 0 is the first formal — confirm globals are not accessed
 * from those. */
static void put(struct fn *f, struct loc l)
{
	assert(l.g != 0);

	if (l.g != -1) {
		/* TODO: always a 64-bit store; the value's type width isn't
		 * checked yet */
		ejit_stxi_64(f->f, EJIT_GPR(l.s), EJIT_GPR(0),
		             l.g * sizeof(int64_t));
	}
}

/* If l is a global, load slot l.g of the global array (addressed through
 * EJIT_GPR(0)) into the stack register l.s chosen for it. Locals are already
 * in their register, so nothing is emitted for them.
 * NOTE(review): EJIT_GPR(0) holds the global-array pointer only in the
 * file-level main function; in procedures and functions register 0 is the
 * first formal — confirm. */
static void get(struct fn *f, struct loc l)
{
	assert(l.g != 0);

	if (l.g != -1) {
		ejit_ldxi_u64(f->f, EJIT_GPR(l.s), EJIT_GPR(0),
		              l.g * sizeof(int64_t));
	}
}

/* Lower a local variable definition. The variable claims the next stack slot,
 * its initializer is evaluated with the stack pointer already past that slot,
 * and the result is copied into the variable's register. */
static void lower_var(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	f->sp += 1;

	struct ast *init = var_expr(n);
	lower(f, init);
	ejit_movr(f->f, reg(n), reg(init));
}

/* A formal parameter just claims the next stack slot. */
static void lower_formal(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	f->sp += 1;
}

static struct vec lower_formals(struct fn *f, struct ast *n)
{
	struct vec formals = vec_create(sizeof(struct ejit_operand));
	foreach_node(fn, n) {
		lower_formal(f, fn);

		struct ejit_operand arg = EJIT_OPERAND_GPR(
			regno(fn),
			ejit_type_from(fn->t)
			);
		vec_append(&formals, &arg);
	}

	return formals;
}

static void lower_date(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	ejit_movi(f->f, reg(n), n->v);
}

static void lower_string(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	ejit_movi(f->f, reg(n), (int64_t)n->id);
}

static void lower_int(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	ejit_movi(f->f, reg(n), n->v);
}

/* Lower n = l + r into the current top-of-stack slot.
 *
 * If either operand is an integer constant, an add-immediate is emitted and
 * only the other operand is evaluated, with the stack pointer bumped past the
 * result slot. Otherwise l and r are evaluated into the two slots above it.
 *
 * ejit_addi() is called as a statement followed by a bare return. Returning an
 * expression from a void function is a constraint violation in ISO C. */
static void lower_add(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	struct ast *l = add_l(n);
	struct ast *r = add_r(n);

	/* we could implement some kind of constant folding here, but it might
	 * be better to do it during type checking or something so we don't mess
	 * with stack locations too much */
	if (r->k == AST_CONST_INT) {
		f->sp += 1; lower(f, l);
		f->sp -= 1;
		ejit_addi(f->f, reg(n), reg(l), r->v);
		return;
	}

	if (l->k == AST_CONST_INT) {
		f->sp += 1; lower(f, r);
		f->sp -= 1;
		ejit_addi(f->f, reg(n), reg(r), l->v);
		return;
	}

	/* since variable definitions can't appear in expressions, there's no
	 * danger of some variable claiming a stack location above this point and
	 * it being overwritten by accident. If variable definitions were
	 * expressions, we would probably have to do an extra pass before
	 * lowering that pre-assigns locations to all variables */
	f->sp += 1; lower(f, l);
	f->sp += 1; lower(f, r);
	/* technically we could reuse the stack location for l, but I found that
	 * reading the generated instruction sequences was a bit easier with
	 * separate output and input stack locations */
	f->sp -= 2;

	ejit_addr(f->f, reg(n), reg(l), reg(r));
}

/* Lower n = l - r into the current top-of-stack slot. A constant right-hand
 * side becomes a subtract-immediate; otherwise both operands are evaluated
 * into the two slots above the result. */
static void lower_sub(struct fn *f, struct ast *n)
{
	struct ast *lhs = sub_l(n);
	struct ast *rhs = sub_r(n);
	n->l = build_local_loc(f->sp);

	if (rhs->k != AST_CONST_INT) {
		f->sp += 1;
		lower(f, lhs);
		f->sp += 1;
		lower(f, rhs);
		f->sp -= 2;

		ejit_subr(f->f, reg(n), reg(lhs), reg(rhs));
		return;
	}

	/* constant subtrahend: only the left operand needs evaluating */
	f->sp += 1;
	lower(f, lhs);
	f->sp -= 1;

	ejit_subi(f->f, reg(n), reg(lhs), rhs->v);
}

/* Lower n = l * r: operands go into the two slots above n's slot. */
static void lower_mul(struct fn *f, struct ast *n)
{
	struct ast *lhs = mul_l(n);
	struct ast *rhs = mul_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, lhs);
	f->sp += 1;
	lower(f, rhs);
	f->sp -= 2;

	ejit_mulr(f->f, reg(n), reg(lhs), reg(rhs));
}

/* Lower n = l / r: operands go into the two slots above n's slot. No check
 * for division by zero is emitted here. */
static void lower_div(struct fn *f, struct ast *n)
{
	struct ast *lhs = div_l(n);
	struct ast *rhs = div_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, lhs);
	f->sp += 1;
	lower(f, rhs);
	f->sp -= 2;

	ejit_divr(f->f, reg(n), reg(lhs), reg(rhs));
}

/* Runtime helper: return the date args[0] with its year replaced by the
 * integer args[1]; month and day are kept. */
static long escape_store_year(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_INT));

	unsigned month = 0, day = 0;
	date_split(args[0].u64, NULL, &month, &day);

	return date_from_numbers(args[1].u64, month, day);
}

/* Emit a call to escape_store_year() with the date in base's register and the
 * new year in r's register; the updated date is written back into base's
 * register. */
static void lower_store_year(struct fn *f, struct ast *base, struct ast *r)
{
	assert(base->t == TYPE_DATE);
	assert(r->t == TYPE_INT);

	struct ejit_operand args[2];
	args[0] = EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE));
	args[1] = EJIT_OPERAND_GPR(regno(r), ejit_type_from(TYPE_INT));

	ejit_escapei(f->f, escape_store_year, 2, args);
	ejit_retval(f->f, reg(base));
}

/* Runtime helper: return the date args[0] with its month replaced by the
 * integer args[1]; year and day are kept. */
static long escape_store_month(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_INT));

	unsigned year = 0, day = 0;
	date_split(args[0].u64, &year, NULL, &day);

	return date_from_numbers(year, args[1].u64, day);
}

/* Emit a call to escape_store_month() with the date in base's register and
 * the new month in r's register; the updated date is written back into base's
 * register. */
static void lower_store_month(struct fn *f, struct ast *base, struct ast *r)
{
	assert(base->t == TYPE_DATE);
	assert(r->t == TYPE_INT);

	struct ejit_operand args[2];
	args[0] = EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE));
	args[1] = EJIT_OPERAND_GPR(regno(r), ejit_type_from(TYPE_INT));

	ejit_escapei(f->f, escape_store_month, 2, args);
	ejit_retval(f->f, reg(base));
}

/* Runtime helper: return the date args[0] with its day replaced by the
 * integer args[1]; year and month are kept. */
static long escape_store_day(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_INT));

	unsigned year = 0, month = 0;
	date_split(args[0].u64, &year, &month, NULL);

	return date_from_numbers(year, month, args[1].u64);
}

/* Emit a call to escape_store_day() with the date in base's register and the
 * new day in r's register; the updated date is written back into base's
 * register. */
static void lower_store_day(struct fn *f, struct ast *base, struct ast *r)
{
	assert(base->t == TYPE_DATE);
	assert(r->t == TYPE_INT);

	struct ejit_operand args[2];
	args[0] = EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE));
	args[1] = EJIT_OPERAND_GPR(regno(r), ejit_type_from(TYPE_INT));

	ejit_escapei(f->f, escape_store_day, 2, args);
	ejit_retval(f->f, reg(base));
}

static void lower_dot_assign(struct fn *f, struct ast *n)
{
	struct ast *r = assign_r(n);
	struct ast *l = assign_l(n);

	lower(f, r);

	struct ast *base = dot_base(l);
	lower(f, base);

	if (same_id(l->id, "day"))
		lower_store_day(f, l, r);

	else if (same_id(l->id, "month"))
		lower_store_month(f, l, r);

	else if (same_id(l->id, "year"))
		lower_store_year(f, l, r);
	else
		abort();

	n->l = base->l;
	put(f, n->l);
}

static void get_id_loc(struct fn *f, struct ast *n)
{
	assert(n->k == AST_ID);
	struct ast *exists = file_scope_find(n->scope, n->id);
	assert(exists);

	/* using ast nodes/scope lookup as convenient way to store variable ->
	 * location mappings */
	n->l = exists->l;
	assert(n->l.g != 0);
	if (n->l.g != -1) {
		/* global variables should get loaded to the stack for handling */
		n->l.s = f->sp;
	}
}

static void lower_id(struct fn *f, struct ast *n)
{
	/* first calculate location */
	get_id_loc(f, n);
	/* then value at location */
	get(f, n->l);
}

/* Lower an assignment statement.
 *
 * Field assignments (AST_DOT on the left) are delegated to
 * lower_dot_assign(). For a plain variable, the right-hand side is evaluated
 * and copied into the variable's register, and for globals it is then written
 * to the global array.
 *
 * The delegation is a statement followed by a bare return. Returning an
 * expression from a void function is a constraint violation in ISO C. */
static void lower_assign(struct fn *f, struct ast *n)
{
	struct ast *l = assign_l(n);
	if (l->k == AST_DOT) {
		lower_dot_assign(f, n);
		return;
	}

	struct ast *r = assign_r(n);

	lower(f, r);
	/* get location for variable, at this point we should be certain that
	 * we're dealing with a simple variable */
	get_id_loc(f, l);
	ejit_movr(f->f, reg(l), reg(r));
	/* maybe store possible global value to global array */
	put(f, l->l);

	n->l = null_loc();
}


/* Lower `return expr`. Integer constants are returned as an immediate;
 * anything else is evaluated first and returned from its register. */
static void lower_return(struct fn *f, struct ast *n)
{
	struct ast *value = return_expr(n);
	n->l = null_loc();

	if (value->k != AST_CONST_INT) {
		lower(f, value);
		ejit_retr(f->f, reg(value));
		return;
	}

	ejit_reti(f->f, value->v);
}

static long escape_load_weeknum(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	struct tm time = tm_from_date(args[0].u64);
	return time.tm_yday / 7;
}

static void lower_load_weeknum(struct fn *f, struct ast *n, struct ast *base)
{
	assert(base->t == TYPE_DATE);
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_load_weeknum, 1, args);
	ejit_retval(f->f, reg(n));
}

static long escape_load_weekday(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	struct tm time = tm_from_date(args[0].u64);
	return time.tm_wday;
}

static void lower_load_weekday(struct fn *f, struct ast *n, struct ast *base)
{
	assert(base->t == TYPE_DATE);
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_load_weekday, 1, args);
	ejit_retval(f->f, reg(n));
}

static long escape_load_year(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	unsigned year = 0;
	date_split(args[0].u64, &year, NULL, NULL);
	return year;
}

static void lower_load_year(struct fn *f, struct ast *n, struct ast *base)
{
	assert(base->t == TYPE_DATE);
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_load_year, 1, args);
	ejit_retval(f->f, reg(n));
}

static long escape_load_month(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	unsigned month = 0;
	date_split(args[0].u64, NULL, &month, NULL);
	return month;
}

static void lower_load_month(struct fn *f, struct ast *n, struct ast *base)
{
	assert(base->t == TYPE_DATE);
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_load_month, 1, args);
	ejit_retval(f->f, reg(n));
}

static long escape_load_day(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	unsigned day = 0;
	date_split(args[0].u64, NULL, NULL, &day);
	return day;
}

static void lower_load_day(struct fn *f, struct ast *n, struct ast *base)
{
	assert(base->t == TYPE_DATE);
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(base), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_load_day, 1, args);
	ejit_retval(f->f, reg(n));
}

/* Lower a date attribute read (day, month, year, weekday, weeknum). The
 * result goes in the current top-of-stack slot; unknown attributes abort. */
static void lower_attr(struct fn *f, struct ast *n)
{
	struct ast *base = attr_base(n);
	n->l = build_local_loc(f->sp);
	lower(f, base);

	if (same_id(n->id, "day")) {
		lower_load_day(f, n, base);
		return;
	}

	if (same_id(n->id, "month")) {
		lower_load_month(f, n, base);
		return;
	}

	if (same_id(n->id, "year")) {
		lower_load_year(f, n, base);
		return;
	}

	if (same_id(n->id, "weekday")) {
		lower_load_weekday(f, n, base);
		return;
	}

	if (same_id(n->id, "weeknum")) {
		lower_load_weeknum(f, n, base);
		return;
	}

	abort();
}

/* Runtime helper: write the single space that separates print items. */
static long escape_print_space(size_t argc, const struct ejit_arg args[argc])
{
	(void)args;
	assert(argc == 0);
	fputc(' ', stdout);
	return 0;
}

/* Runtime helper: write the newline that ends a print statement. */
static long escape_print_newline(size_t argc, const struct ejit_arg args[argc])
{
	(void)args;
	assert(argc == 0);
	fputc('\n', stdout);
	return 0;
}

static void lower_print(struct fn *f, struct ast *p)
{
	foreach_node(n, print_items(p)) {
		lower(f, n);

		/* don't print space on last element */
		if (n->n)
			ejit_escapei(f->f, escape_print_space, 0, NULL);
	}

	ejit_escapei(f->f, escape_print_newline, 0, NULL);
	p->l = null_loc();
}

static void lower_proc_call(struct fn *f, struct ast *c)
{
	size_t count = 0;
	foreach_node(n, proc_call_args(c)) {
		/* place each argument in its own stack location to await being
		 * used */
		f->sp += 1; lower(f, n); count += 1;
	}

	size_t i = 0;
	struct ejit_operand args[count];
	foreach_node(n, proc_call_args(c)) {
		args[i] = EJIT_OPERAND_GPR(regno(n), ejit_type_from(n->t));
		i++;
	}

	struct ast *def = file_scope_find(c->scope, proc_call_id(c));
	struct fn *target = vec_at(&fns, def->l.s);
	assert(target);

	ejit_calli(f->f, target->f, count, args);

	f->sp -= count;

	c->l = build_local_loc(f->sp);
	if (c->t != TYPE_VOID)
		ejit_retval(f->f, reg(c));
}

static void lower_func_call(struct fn *f, struct ast *c)
{
	size_t count = 0;
	foreach_node(n, func_call_args(c)) {
		f->sp += 1; lower(f, n); count += 1;
	}

	size_t i = 0;
	struct ejit_operand args[count];
	foreach_node(n, func_call_args(c)) {
		args[i] = EJIT_OPERAND_GPR(regno(n), ejit_type_from(n->t));
		i++;
	}

	struct ast *def = file_scope_find(c->scope, func_call_id(c));
	struct fn *target = vec_at(&fns, def->l.s);
	assert(target);

	ejit_calli(f->f, target->f, count, args);

	f->sp -= count;

	c->l = build_local_loc(f->sp);
	if (c->t != TYPE_VOID)
		ejit_retval(f->f, reg(c));
}

/* Lower l == r, leaving the comparison result in the current top-of-stack
 * slot. Operands are evaluated left to right into the two slots above it. */
static void lower_eq(struct fn *f, struct ast *n)
{
	struct ast *lhs = eq_l(n);
	struct ast *rhs = eq_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, lhs);
	f->sp += 1;
	lower(f, rhs);
	f->sp -= 2;

	ejit_eqr(f->f, reg(n), reg(lhs), reg(rhs));
}

/* Lower l < r, leaving the comparison result in the current top-of-stack
 * slot. Operands are evaluated left to right into the two slots above it. */
static void lower_lt(struct fn *f, struct ast *n)
{
	struct ast *lhs = lt_l(n);
	struct ast *rhs = lt_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, lhs);
	f->sp += 1;
	lower(f, rhs);
	f->sp -= 2;

	ejit_ltr(f->f, reg(n), reg(lhs), reg(rhs));
}

/* Unary plus emits nothing; n simply shares its operand's location. */
static void lower_pos(struct fn *f, struct ast *n)
{
	struct ast *operand = pos_expr(n);
	lower(f, operand);
	n->l = operand->l;
}

/* Unary minus: the operand is evaluated in the slot above n's slot and
 * negated into n's slot. */
static void lower_neg(struct fn *f, struct ast *n)
{
	struct ast *operand = neg_expr(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, operand);
	f->sp -= 1;

	ejit_negr(f->f, reg(n), reg(operand));
}

static long escape_print_date(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_DATE));

	char str[11] = {0};
	date_to_string(str, args[0].u64);
	printf("%s", str);
	return 0;
}

static void lower_print_date(struct fn *f, struct ast *n)
{
	struct ast *expr = print_date_expr(n);
	lower(f, expr);

	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(expr), ejit_type_from(TYPE_DATE))
	};

	ejit_escapei(f->f, escape_print_date, 1, args);
	n->l = null_loc();
}

static long escape_print_int(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_INT));
	printf("%lli", (long long)args[0].l);
	return 0;
}

static void lower_print_int(struct fn *f, struct ast *n)
{
	struct ast *expr = print_int_expr(n);
	lower(f, expr);

	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(expr), ejit_type_from(TYPE_INT))
	};

	ejit_escapei(f->f, escape_print_int, 1, args);
	n->l = null_loc();
}

/* Runtime helper: print string args[0] with no trailing newline.
 *
 * This uses fputs() rather than puts(). puts() appends '\n', which would break
 * the single line that lower_print() builds: items separated by spaces, with
 * one newline at the end. The other print helpers (int, date) emit no
 * newline either. */
static long escape_print_string(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_STRING));
	fputs(args[0].p, stdout);
	return 0;
}

/* Lower a string print item: evaluate the expression and hand its register
 * to escape_print_string(). */
static void lower_print_string(struct fn *f, struct ast *n)
{
	struct ast *expr = print_string_expr(n);
	lower(f, expr);

	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(expr), ejit_type_from(TYPE_STRING))
	};

	ejit_escapei(f->f, escape_print_string, 1, args);
	n->l = null_loc();
}

/* Runtime helper: print "true" or "false" for args[0], with no trailing
 * newline.
 *
 * This uses fputs() rather than puts(), which appends '\n' and would break
 * the space-separated line that lower_print() builds. */
static long escape_print_bool(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 1);
	assert(args[0].type == ejit_type_from(TYPE_BOOL));
	fputs(args[0].l ? "true" : "false", stdout);
	return 0;
}

/* Lower a boolean print item: evaluate the expression and hand its register
 * to escape_print_bool(). */
static void lower_print_bool(struct fn *f, struct ast *n)
{
	struct ast *expr = print_bool_expr(n);
	lower(f, expr);

	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(regno(expr), ejit_type_from(TYPE_BOOL))
	};

	ejit_escapei(f->f, escape_print_bool, 1, args);
	n->l = null_loc();
}

/* Return date d moved forward by i days.
 *
 * The day of month is offset (+=, mirroring date_sub's -=) rather than
 * overwritten with i. mktime() then normalizes an out-of-range tm_mday into
 * the following months/years. */
static ph_date_t date_add(ph_date_t d, long i)
{
	struct tm time = tm_from_date(d);
	time.tm_mday += i;
	mktime(&time);
	return date_from_tm(time);
}

/* Runtime helper: date args[0] plus args[1] days. */
static long escape_date_add(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_INT));
	return date_add(args[0].u64, args[1].l);
}

/* Lower date + days. The operands are evaluated into the two slots above the
 * result, and escape_date_add()'s return value lands in the current
 * top-of-stack slot. */
static void lower_date_add(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);

	struct ast *l = date_add_l(n);
	struct ast *r = date_add_r(n);

	f->sp += 1; lower(f, l);
	f->sp += 1; lower(f, r);
	f->sp -= 2;

	struct ejit_operand args[2] = {
		EJIT_OPERAND_GPR(regno(l), ejit_type_from(TYPE_DATE)),
		EJIT_OPERAND_GPR(regno(r), ejit_type_from(TYPE_INT)),
	};

	ejit_escapei(f->f, escape_date_add, 2, args);
	ejit_retval(f->f, reg(n));
}

/* Return date d moved back by i days; mktime() renormalizes the fields when
 * the day of month drops below 1. */
static ph_date_t date_sub(ph_date_t d, long i)
{
	struct tm t = tm_from_date(d);
	t.tm_mday -= i;
	(void)mktime(&t);
	return date_from_tm(t);
}

/* Runtime helper: date args[0] minus args[1] days. */
static long escape_date_sub(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_INT));

	ph_date_t result = date_sub(args[0].u64, args[1].l);
	return result;
}

/* Lower date - days. The operands are evaluated into the two slots above the
 * result, and escape_date_sub()'s return value lands in the current
 * top-of-stack slot. */
static void lower_date_sub(struct fn *f, struct ast *n)
{
	struct ast *date = date_sub_l(n);
	struct ast *days = date_sub_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, date);
	f->sp += 1;
	lower(f, days);
	f->sp -= 2;

	struct ejit_operand args[2];
	args[0] = EJIT_OPERAND_GPR(regno(date), ejit_type_from(TYPE_DATE));
	args[1] = EJIT_OPERAND_GPR(regno(days), ejit_type_from(TYPE_INT));

	ejit_escapei(f->f, escape_date_sub, 2, args);
	ejit_retval(f->f, reg(n));
}

/* Runtime helper: number of days from date args[1] to date args[0]
 * (args[0] - args[1]). The seconds difference is rounded to the nearest whole
 * day, since it need not be an exact multiple of 86400 (e.g. across a DST
 * change, depending on how tm_from_date() sets tm_isdst). */
static long escape_date_diff(size_t argc, const struct ejit_arg args[argc])
{
	assert(argc == 2);
	assert(args[0].type == ejit_type_from(TYPE_DATE));
	assert(args[1].type == ejit_type_from(TYPE_DATE));

	struct tm lhs = tm_from_date((ph_date_t)args[0].u64);
	struct tm rhs = tm_from_date((ph_date_t)args[1].u64);

	time_t lhs_time = mktime(&lhs);
	time_t rhs_time = mktime(&rhs);

	const double seconds_per_day = 86400;
	return round(difftime(lhs_time, rhs_time) / seconds_per_day);
}

/* Lower date - date. Both operands are evaluated into the two slots above the
 * result, and escape_date_diff()'s return value lands in the current
 * top-of-stack slot. */
static void lower_date_diff(struct fn *f, struct ast *n)
{
	struct ast *lhs = date_diff_l(n);
	struct ast *rhs = date_diff_r(n);
	n->l = build_local_loc(f->sp);

	f->sp += 1;
	lower(f, lhs);
	f->sp += 1;
	lower(f, rhs);
	f->sp -= 2;

	struct ejit_operand args[2];
	args[0] = EJIT_OPERAND_GPR(regno(lhs), ejit_type_from(TYPE_DATE));
	args[1] = EJIT_OPERAND_GPR(regno(rhs), ejit_type_from(TYPE_DATE));

	ejit_escapei(f->f, escape_date_diff, 2, args);
	ejit_retval(f->f, reg(n));
}

/* Emit code that evaluates cond and branches when its truth value equals
 * `branch`. Returns the relocation so the caller can patch in the target.
 *
 * Comparisons against an integer constant use the immediate branch forms, and
 * a constant condition whose branch is always taken becomes an unconditional
 * jump. Everything else is evaluated into a register and compared with 0.
 *
 * For register-register comparisons, l is evaluated into the current
 * top-of-stack slot and r into the slot above it. Lowering both at f->sp, as
 * temporaries normally are, would let the second operand overwrite the first,
 * so the branch would compare a value with itself. Operands are evaluated
 * left to right, like lower_eq()/lower_lt(). */
static struct ejit_reloc branch_if(struct fn *f, struct ast *cond, bool branch)
{
	if (cond->k == AST_EQ && (eq_r(cond))->k == AST_CONST_INT) {
		struct ast *r = eq_r(cond);
		struct ast *l = eq_l(cond);
		lower(f, l);
		return (branch ? ejit_beqi : ejit_bnei)(f->f, reg(l), r->v);
	}

	if (cond->k == AST_EQ && (eq_l(cond))->k == AST_CONST_INT) {
		struct ast *r = eq_r(cond);
		struct ast *l = eq_l(cond);
		lower(f, r);
		return (branch ? ejit_beqi : ejit_bnei)(f->f, reg(r), l->v);
	}

	if (cond->k == AST_EQ) {
		struct ast *r = eq_r(cond);
		struct ast *l = eq_l(cond);
		lower(f, l);
		f->sp += 1; lower(f, r);
		f->sp -= 1;
		return (branch ? ejit_beqr : ejit_bner)(f->f, reg(l), reg(r));
	}

	if (cond->k == AST_LT && (lt_r(cond))->k == AST_CONST_INT) {
		struct ast *r = lt_r(cond);
		struct ast *l = lt_l(cond);
		lower(f, l);
		return (branch ? ejit_blti : ejit_bgei)(f->f, reg(l), r->v);
	}

	if (cond->k == AST_LT && (lt_l(cond))->k == AST_CONST_INT) {
		struct ast *r = lt_r(cond);
		struct ast *l = lt_l(cond);
		lower(f, r);
		/* c < r  <=>  r > c */
		return (branch ? ejit_bgti : ejit_blei)(f->f, reg(r), l->v);
	}

	if (cond->k == AST_LT) {
		struct ast *r = lt_r(cond);
		struct ast *l = lt_l(cond);
		lower(f, l);
		f->sp += 1; lower(f, r);
		f->sp -= 1;
		return (branch ? ejit_bltr : ejit_bger)(f->f, reg(l), reg(r));
	}

	if (cond->k == AST_CONST_INT) {
		if (branch && cond->v)
			/* always jump */
			return ejit_jmp(f->f);

		if (!branch && !cond->v)
			return ejit_jmp(f->f);

		/** @todo here it would be useful to have some kind of nop with
		 * relocation, since we must return a branch of some kind, but
		 * we know that we never want to jump. */
	}

	/* fallback */
	lower(f, cond);
	return (branch ? ejit_bnei : ejit_beqi)(f->f, reg(cond), 0);
}

static void lower_until(struct fn *f, struct ast *n)
{
	struct ejit_label l = ejit_label(f->f);
	ejit_inc_prio(f->f, 100); /* completely arbitrary */

	lower_list(f, until_body(n));

	struct ast *cond = until_cond(n);
	struct ejit_reloc branch = branch_if(f, cond, false);
	ejit_patch(f->f, branch, l);
	ejit_dec_prio(f->f, 100);

	n->l = null_loc();
}

static void lower_unless(struct fn *f, struct ast *n)
{
	struct ast *cond = unless_cond(n);
	struct ejit_reloc branch = branch_if(f, cond, true);

	struct ast *body = unless_body(n);
	lower_list(f, body);

	struct ejit_reloc jump = ejit_jmp(f->f);

	struct ejit_label l = ejit_label(f->f);
	ejit_patch(f->f, branch, l);

	struct ast *otherwise = unless_otherwise(n);
	lower_list(f, otherwise);

	l = ejit_label(f->f);
	ejit_patch(f->f, jump, l);

	n->l = null_loc();
}

static void lower_unless_expr(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);

	struct ast *cond = unless_expr_cond(n);
	struct ejit_reloc branch = branch_if(f, cond, true);

	struct ast *body = unless_expr_body(n);
	lower(f, body);
	ejit_movr(f->f, reg(n), reg(body));

	struct ejit_reloc jump = ejit_jmp(f->f);

	struct ejit_label l = ejit_label(f->f);
	ejit_patch(f->f, branch, l);

	struct ast *otherwise = unless_expr_otherwise(n);
	lower(f, otherwise);
	ejit_movr(f->f, reg(n), reg(otherwise));

	l = ejit_label(f->f);
	ejit_patch(f->f, jump, l);
}

/* Runtime helper for the builtin Today(): the current date. */
static long escape_today(size_t argc, const struct ejit_arg args[argc])
{
	(void)args;
	assert(argc == 0);
	return current_date();
}

/* Lower a builtin call. Only Today() is supported so far, and it takes no
 * arguments; its result lands in the current top-of-stack slot. */
static void lower_builtin_call(struct fn *f, struct ast *n)
{
	assert(same_id(builtin_call_id(n), "Today"));

	n->l = build_local_loc(f->sp);
	ejit_escapei(f->f, escape_today, 0, NULL);
	ejit_retval(f->f, reg(n));
}

static void lower(struct fn *f, struct ast *n)
{
	/* technically it would be possible for some lowering to use stack space
	 * without calling lower(), probably causing difficult to debug runtime weirdness.
	 * Not currently an issue, but something that would need to be taken
	 * into account if this instruction set was extended at some point. */
	if (f->max_sp < f->sp)
		f->max_sp = f->sp;

	switch (n->k) {
	case AST_PROC_DEF: break;
	case AST_FUNC_DEF: break;
	case AST_VAR_DEF: lower_var(f, n); break;
	case AST_FORMAL_DEF: lower_formal(f, n); break;
	case AST_CONST_STRING: lower_string(f, n); break;
	case AST_CONST_DATE: lower_date(f, n); break;
	case AST_CONST_INT: lower_int(f, n); break;
	case AST_ASSIGN: lower_assign(f, n); break;
	case AST_ADD: lower_add(f, n); break;
	case AST_SUB: lower_sub(f, n); break;
	case AST_MUL: lower_mul(f, n); break;
	case AST_DIV: lower_div(f, n); break;
	case AST_ID: lower_id(f, n); break;
	case AST_RETURN: lower_return(f, n); break;
	case AST_ATTR: lower_attr(f, n); break;
	case AST_PRINT: lower_print(f, n); break;
	case AST_PROC_CALL: lower_proc_call(f, n); break;
	case AST_FUNC_CALL: lower_func_call(f, n); break;
	case AST_BUILTIN_CALL: lower_builtin_call(f, n); break;
	case AST_EQ: lower_eq(f, n); break;
	case AST_LT: lower_lt(f, n); break;
	case AST_POS: lower_pos(f, n); break;
	case AST_NEG: lower_neg(f, n); break;
	case AST_PRINT_DATE: lower_print_date(f, n); break;
	case AST_PRINT_STRING: lower_print_string(f, n); break;
	case AST_PRINT_BOOL: lower_print_bool(f, n); break;
	case AST_PRINT_INT: lower_print_int(f, n); break;
	case AST_DATE_ADD: lower_date_add(f, n); break;
	case AST_DATE_SUB: lower_date_sub(f, n); break;
	case AST_DATE_DIFF: lower_date_diff(f, n); break;
	case AST_UNTIL: lower_until(f, n); break;
	case AST_UNLESS: lower_unless(f, n); break;
	case AST_UNLESS_EXPR: lower_unless_expr(f, n); break;

	/* handled by assign */
	case AST_DOT: break;
	}

	/* each ast node is assigned some location, regardless of if it actually
	 * needs one as a sanity check */
	assert(n->l.g || n->l.s);
}

static void lower_global_var(struct fn *f, struct ast *n) {
	n->l = build_global_loc(globals++);
	struct ast *expr = var_expr(n);
	lower(f, expr);

	/* move directly from expr result to global */
	struct loc l = expr->l;
	l.g = n->l.g;
	put(f, l);
}

static void add_proc(struct ast *n) {
	size_t idx = vec_len(&fns);
	n->l = build_local_loc(idx);
	struct fn f = {.name = proc_id(n),
		       .idx = idx,
		       .sp = 0,
		       .max_sp = 0};

	vect_append(struct fn, fns, &f);
}

static void add_func(struct ast *n) {
	size_t idx = vec_len(&fns);
	n->l = build_local_loc(idx);
	struct fn f = {.name = func_id(n),
		       .idx = idx,
		       .sp = 0,
		       .max_sp = 0};

	vect_append(struct fn, fns, &f);
}

/* Build and JIT-compile the ejit function for procedure d. Formals occupy the
 * first stack slots, followed by the local variables and then the body. Void
 * procedures get an implicit `return 0` at the end. */
static void lower_proc_def(struct ast *d)
{
	struct fn *f = vec_at(&fns, d->l.s);
	assert(f);

	struct vec formals = lower_formals(f, proc_formals(d));
	/* NOTE(review): formals.buf is freed right after ejit_create_func();
	 * this assumes ejit copies the operand array — confirm */
	f->f = ejit_create_func(ejit_type_from(d->t), vec_len(&formals),
	                        formals.buf);
	f->sp = vec_len(&formals);
	f->max_sp = vec_len(&formals);
	vec_destroy(&formals);

	lower_list(f, proc_vars(d));
	lower_list(f, proc_body(d));

	if (d->t == TYPE_VOID)
		ejit_reti(f->f, 0);

	/* ph_date_t is inherently 64-bit, so the 32-bit JIT isn't an option */
	ejit_select_compile_func(f->f, f->max_sp + 1, 0, true, JIT);
}

/* Build and JIT-compile the ejit function for function definition d. Formals
 * occupy the first stack slots, followed by the local variables and then the
 * body.
 *
 * The ejit return type is derived from d->t, as in lower_proc_def(). The
 * original hard-coded EJIT_VOID, but lower_func_call() reads a return value
 * (ejit_retval) whenever the call's type is not TYPE_VOID, so the callee must
 * be declared with its real return type.
 * NOTE(review): no return is emitted after the body; this assumes the body
 * always ends in a return statement — confirm against the parser. */
static void lower_func_def(struct ast *d)
{
	struct fn *f = vec_at(&fns, d->l.s);
	assert(f);

	struct vec formals = lower_formals(f, func_formals(d));
	f->f = ejit_create_func(ejit_type_from(d->t), vec_len(&formals),
	                        formals.buf);
	f->sp = vec_len(&formals); f->max_sp = vec_len(&formals);
	vec_destroy(&formals);

	lower_list(f, func_vars(d));
	lower_list(f, func_body(d));

	/* ph_date_t is inherently 64bit so we can't really use 32bit JIT */
	ejit_select_compile_func(f->f, f->max_sp + 1, 0, true, JIT);
}

/* Look up lowered function number idx; index 0 is the file-level main
 * function. */
struct fn *find_fn(size_t idx)
{
	struct fn *found = vec_at(&fns, idx);
	return found;
}

int lower_ast(struct ast *tree)
{
	fns = vec_create(sizeof(struct fn));

	/* make body of file out to be a kind of main function, it will always
	 * be at index 0 */
	struct fn main = {.name = "main",
		          .idx = 0,
		          .sp = 0,
		          .max_sp = 0};

	vect_append(struct fn, fns, &main);

	/* first create function nodes in fns to assign each procedure an index */
	foreach_node(n, tree) {
		switch (n->k) {
		case AST_PROC_DEF: add_proc(n); break;
		case AST_FUNC_DEF: add_func(n); break;
		default:
		}
	}

	struct fn *f = &vect_at(struct fn, fns, 0);

	/* reserve arg 0 for global pointer */
	struct ejit_operand args[1] = {
		EJIT_OPERAND_GPR(0, EJIT_POINTER)
	};
	f->sp = 1; f->max_sp = 1;
	f->f = ejit_create_func(EJIT_VOID, 1, args);

	/* we can't treat file scope as a regular function, as all variable
	 * definitions must be made global, so we have a little bit of
	 * duplicated code here but that's fine */
	foreach_node(n, tree) {
		switch (n->k) {
		case AST_VAR_DEF: lower_global_var(f, n); break;
		case AST_PROC_DEF: lower_proc_def(n); break;
		case AST_FUNC_DEF: lower_func_def(n); break;
		default: lower(f, n);
		}
	}

	ejit_select_compile_func(f->f, f->max_sp + 1, 0, true, JIT);
	return 0;
}

/* Release the ejit function owned by f. */
static void destroy_fn(struct fn *f)
{
	ejit_destroy_func(f->f);
}

/* Destroy every lowered function and then the fns vector itself. Declared
 * with (void) so it is a real prototype before C23. */
void destroy_lowering(void)
{
	foreach_vec(fi, fns) {
		struct fn *f = vec_at(&fns, fi);
		destroy_fn(f);
	}

	vec_destroy(&fns);
}