#include <stdlib.h>
#include <sys/mman.h>

#include <posthaste/lower.h>
#include <posthaste/scope.h>

/* Mark a variable or parameter as intentionally unused.  The argument is
 * parenthesized so that expression arguments expand correctly. */
#define UNUSED(x) (void)(x)

/* Every function being lowered, indexed by its global index.  Entry 0 is the
 * implicit "main" function that collects top-level code (see lower_ast). */
static struct vec fns = {0};
/* Next free global slot.  Zero means "uninitialized" and global 1 is reserved
 * as the null location, so allocation skips the first two globals. */
static size_t globals = 2;

static void lower(struct fn *f, struct ast *n);
static void lower_list(struct fn *f, struct ast *l)
{
	foreach_node(n, l) {
		lower(f, n);
	}
}

/*
 * Pack a location into a single integer: the offset in the low bits, OR'd
 * with the local flag shifted into the most significant bit.  A result of
 * zero therefore means "global location 0", which this file treats as
 * uninitialized.
 */
static size_t loc_as_int(struct loc l)
{
	const size_t local_bit = (size_t)(l.l) << (sizeof(size_t) * 8 - 1);
	return (size_t)l.o | local_bit;
}

static struct loc build_local_loc(size_t idx)
{
	return (struct loc){.l = 1, .o = idx};
}

static struct loc build_global_loc(size_t idx)
{
	return (struct loc){.l = 0, .o = idx};
}

/*
 * The "no location" marker, used for unused instruction operands and for
 * nodes that produce no value.  It is encoded as global slot 1, which is
 * reserved for this purpose (globals are allocated starting at 2).
 */
static struct loc null_loc(void)
{
	/* don't exactly love this but I guess it works */
	return build_global_loc(1);
}

static void output_insn(struct fn *f, enum insn_kind k, struct loc o,
                        struct loc i0, struct loc i1, int64_t v)
{
	struct insn i = {.k = k, .o = o, .i0 = i0, .i1 = i1, .v = v};
	vect_append(struct insn, f->insns, &i);
}

static void output_label(struct fn *f)
{
	output_insn(f, LABEL, null_loc(), null_loc(), null_loc(), 0);
}

/*
 * Lower a local variable definition.
 *
 * The variable takes the current slot and f->sp is bumped past it without
 * being restored, so the slot stays reserved.  The initializer is then
 * lowered above it and copied in with MOVE.
 */
static void lower_var(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	f->sp += 1;

	struct ast *init = var_expr(n);
	lower(f, init);
	output_insn(f, MOVE, n->l, init->l, null_loc(), 0);
}

static void lower_formal(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp++);
}

static void lower_date(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	output_insn(f, CONST, n->l, null_loc(), null_loc(), n->v);
}

static void lower_string(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	output_insn(f, CONST, n->l, null_loc(), null_loc(), (int64_t)n->id);
}

static void lower_int(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);
	output_insn(f, CONST, n->l, null_loc(), null_loc(), n->v);
}

/*
 * Lower `l + r`.  The result goes in the current slot.  The operands are
 * lowered with f->sp bumped by one and by two, so their temporaries sit
 * above the result slot.
 */
static void lower_add(struct fn *f, struct ast *n)
{
	struct ast *lhs = add_l(n);
	struct ast *rhs = add_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, ADD, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower `l - r`.  The result goes in the current slot.  The operands are
 * lowered with f->sp bumped by one and by two, so their temporaries sit
 * above the result slot.
 */
static void lower_sub(struct fn *f, struct ast *n)
{
	struct ast *lhs = sub_l(n);
	struct ast *rhs = sub_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, SUB, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower `l * r`.  The result goes in the current slot.  The operands are
 * lowered with f->sp bumped by one and by two, so their temporaries sit
 * above the result slot.
 */
static void lower_mul(struct fn *f, struct ast *n)
{
	struct ast *lhs = mul_l(n);
	struct ast *rhs = mul_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, MUL, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower `l / r`.  The result goes in the current slot.  The operands are
 * lowered with f->sp bumped by one and by two, so their temporaries sit
 * above the result slot.
 */
static void lower_div(struct fn *f, struct ast *n)
{
	struct ast *lhs = div_l(n);
	struct ast *rhs = div_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, DIV, n->l, lhs->l, rhs->l, 0);
}

static void lower_dot_assign(struct fn *f, struct ast *n)
{
	struct ast *r = assign_r(n);
	struct ast *l = assign_l(n);

	lower(f, r);

	struct ast *base = dot_base(l);
	lower(f, base);

	enum insn_kind store = STORE_DAY;
	if (same_id(l->id, "day"))
		store = STORE_DAY;

	else if (same_id(l->id, "month"))
		store = STORE_MONTH;

	else if (same_id(l->id, "year"))
		store = STORE_YEAR;
	else
		abort();

	output_insn(f, store, base->l, base->l, r->l, 0);
	n->l = base->l;
}

/*
 * Lower an assignment `l := r`.
 *
 * Assignments to a date field (l is an AST_DOT) are delegated to
 * lower_dot_assign.  Otherwise r is lowered, l is resolved to its location
 * (presumably an identifier), and a MOVE copies r's value into it.  A plain
 * assignment yields no value, so its location is null.
 */
static void lower_assign(struct fn *f, struct ast *n)
{
	struct ast *l = assign_l(n);
	if (l->k == AST_DOT) {
		/* a void function may not `return` a void expression (C11 6.8.6.4p1) */
		lower_dot_assign(f, n);
		return;
	}

	struct ast *r = assign_r(n);

	lower(f, r);
	lower(f, l);
	output_insn(f, MOVE, l->l, r->l, null_loc(), 0);
	n->l = null_loc();
}

/*
 * Lower an identifier.  No code is emitted: the node aliases the location
 * of the definition found by scope lookup, which must exist.
 */
static void lower_id(struct fn *f, struct ast *n)
{
	UNUSED(f);

	struct ast *def = file_scope_find(n->scope, n->id);
	assert(def);
	n->l = def->l;
}

/* Lower `return expr`: evaluate expr and emit RET with it as input. */
static void lower_return(struct fn *f, struct ast *n)
{
	struct ast *value = return_expr(n);
	lower(f, value);

	output_insn(f, RET, null_loc(), value->l, null_loc(), 0);

	/* the statement itself has no value */
	n->l = null_loc();
}

static void lower_attr(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);

	struct ast *base = attr_base(n);
	lower(f, base);

	enum insn_kind load = LOAD_DAY;
	if (same_id(n->id, "day"))
		load = LOAD_DAY;

	else if (same_id(n->id, "month"))
		load = LOAD_MONTH;

	else if (same_id(n->id, "year"))
		load = LOAD_YEAR;

	else if (same_id(n->id, "weekday"))
		load = LOAD_WEEKDAY;

	else if (same_id(n->id, "weeknum"))
		load = LOAD_WEEKNUM;

	else
		abort();

	output_insn(f, load, n->l, base->l, null_loc(), 0);
}

static void lower_print(struct fn *f, struct ast *p)
{
	foreach_node(n, print_items(p)) {
		lower(f, n);

		/* don't print space on last element */
		if (n->n)
			output_insn(f, PRINT_SPACE, null_loc(), null_loc(),
			            null_loc(), 0);
	}

	output_insn(f, PRINT_NEWLINE, null_loc(), null_loc(), null_loc(), 0);
	p->l = null_loc();
}

/*
 * Lower a procedure call.
 *
 * Arguments are lowered into the slots above the current one
 * (sp + 1 .. sp + count) and passed by position with ARG.  The call target
 * is the definition's global location, packed into the CALL immediate.
 *
 * The call's value, if it has one, is written to the current slot f->sp,
 * the same convention every other expression uses.  Placing it at
 * sp + count (as before) left it above the caller's live slot, where a
 * sibling operand could overwrite it: in `f(1, 2) + (3 + 4)`, the CONST
 * for 3 landed on f's result.
 */
static void lower_proc_call(struct fn *f, struct ast *c)
{
	/* reserve the result slot before the arguments claim the ones above it */
	c->l = build_local_loc(f->sp);

	size_t count = 0;
	foreach_node(n, proc_call_args(c)) {
		f->sp += 1; lower(f, n); count += 1;
	}

	size_t i = 0;
	foreach_node(n, proc_call_args(c)) {
		output_insn(f, ARG, null_loc(), n->l, null_loc(), i);
		i++;
	}

	struct ast *def = file_scope_find(c->scope, proc_call_id(c));
	assert(def);
	output_insn(f, CALL, null_loc(), null_loc(), null_loc(),
	            loc_as_int(def->l));

	if (c->t != TYPE_VOID)
		output_insn(f, RETVAL, c->l, null_loc(), null_loc(), 0);

	f->sp -= count;
}

/*
 * Lower a function call.
 *
 * Arguments are lowered into the slots above the current one
 * (sp + 1 .. sp + count) and passed by position with ARG.  The call target
 * is the definition's global location, packed into the CALL immediate.
 *
 * The call's value, if it has one, is written to the current slot f->sp,
 * the same convention every other expression uses.  Placing it at
 * sp + count (as before) left it above the caller's live slot, where a
 * sibling operand could overwrite it: in `f(1, 2) + (3 + 4)`, the CONST
 * for 3 landed on f's result.
 */
static void lower_func_call(struct fn *f, struct ast *c)
{
	/* reserve the result slot before the arguments claim the ones above it */
	c->l = build_local_loc(f->sp);

	size_t count = 0;
	foreach_node(n, func_call_args(c)) {
		f->sp += 1; lower(f, n); count += 1;
	}

	size_t i = 0;
	foreach_node(n, func_call_args(c)) {
		output_insn(f, ARG, null_loc(), n->l, null_loc(), i);
		i++;
	}

	struct ast *def = file_scope_find(c->scope, func_call_id(c));
	assert(def);
	output_insn(f, CALL, null_loc(), null_loc(), null_loc(),
	            loc_as_int(def->l));

	if (c->t != TYPE_VOID)
		output_insn(f, RETVAL, c->l, null_loc(), null_loc(), 0);

	f->sp -= count;
}

/*
 * Lower `l = r` (equality).  The result goes in the current slot.  The
 * operands are lowered with f->sp bumped by one and by two.
 */
static void lower_eq(struct fn *f, struct ast *n)
{
	struct ast *lhs = eq_l(n);
	struct ast *rhs = eq_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, EQ, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower `l < r`.  The result goes in the current slot.  The operands are
 * lowered with f->sp bumped by one and by two.
 */
static void lower_lt(struct fn *f, struct ast *n)
{
	struct ast *lhs = lt_l(n);
	struct ast *rhs = lt_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, LT, n->l, lhs->l, rhs->l, 0);
}

/* Unary plus is a no-op: the node simply shares its operand's location. */
static void lower_pos(struct fn *f, struct ast *n)
{
	struct ast *operand = pos_expr(n);
	lower(f, operand);
	n->l = operand->l;
}

/* Lower unary minus: the operand is lowered one slot up, NEG writes the current slot. */
static void lower_neg(struct fn *f, struct ast *n)
{
	struct ast *operand = neg_expr(n);
	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, operand);
	--f->sp;

	output_insn(f, NEG, n->l, operand->l, null_loc(), 0);
}

/* Lower printing of a date-valued expression; the node has no value. */
static void lower_print_date(struct fn *f, struct ast *n)
{
	struct ast *value = print_date_expr(n);
	lower(f, value);

	output_insn(f, PRINT_DATE, null_loc(), value->l, null_loc(), 0);
	n->l = null_loc();
}

/* Lower printing of an int-valued expression; the node has no value. */
static void lower_print_int(struct fn *f, struct ast *n)
{
	struct ast *value = print_int_expr(n);
	lower(f, value);

	output_insn(f, PRINT_INT, null_loc(), value->l, null_loc(), 0);
	n->l = null_loc();
}

/* Lower printing of a string-valued expression; the node has no value. */
static void lower_print_string(struct fn *f, struct ast *n)
{
	struct ast *value = print_string_expr(n);
	lower(f, value);

	output_insn(f, PRINT_STRING, null_loc(), value->l, null_loc(), 0);
	n->l = null_loc();
}

/* Lower printing of a bool-valued expression; the node has no value. */
static void lower_print_bool(struct fn *f, struct ast *n)
{
	struct ast *value = print_bool_expr(n);
	lower(f, value);

	output_insn(f, PRINT_BOOL, null_loc(), value->l, null_loc(), 0);
	n->l = null_loc();
}

/*
 * Lower date addition.  The result goes in the current slot.  The operands
 * are lowered with f->sp bumped by one and by two.
 */
static void lower_date_add(struct fn *f, struct ast *n)
{
	struct ast *lhs = date_add_l(n);
	struct ast *rhs = date_add_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, DATE_ADD, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower date subtraction.  The result goes in the current slot.  The
 * operands are lowered with f->sp bumped by one and by two.
 */
static void lower_date_sub(struct fn *f, struct ast *n)
{
	struct ast *lhs = date_sub_l(n);
	struct ast *rhs = date_sub_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, DATE_SUB, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower the difference of two dates.  The result goes in the current slot.
 * The operands are lowered with f->sp bumped by one and by two.
 */
static void lower_date_diff(struct fn *f, struct ast *n)
{
	struct ast *lhs = date_diff_l(n);
	struct ast *rhs = date_diff_r(n);

	n->l = build_local_loc(f->sp);

	++f->sp;
	lower(f, lhs);
	++f->sp;
	lower(f, rhs);
	f->sp -= 2;

	output_insn(f, DATE_DIFF, n->l, lhs->l, rhs->l, 0);
}

/*
 * Lower `do body until cond`.  A label marks the loop top.  The body runs,
 * then cond is evaluated, and BZ jumps back to the top.
 * NOTE(review): BZ presumably branches when cond is zero, i.e. the loop
 * repeats while cond is false.
 */
static void lower_until(struct fn *f, struct ast *n)
{
	const size_t top = vec_len(&f->insns);
	output_label(f);

	lower_list(f, until_body(n));

	struct ast *cond = until_cond(n);
	lower(f, cond);

	output_insn(f, BZ, null_loc(), cond->l, null_loc(), top);
	n->l = null_loc();
}

static void patch_branch(struct fn *f, size_t branch, size_t off)
{
	assert((vect_at(struct insn, f->insns, off)).k == LABEL);

	struct insn i = vect_at(struct insn, f->insns, branch);
	i.v = off;
	vect_at(struct insn, f->insns, branch) = i;
}

/*
 * Lower `unless cond do body otherwise alt`.
 *
 * Layout:
 *     <cond>
 *     B cond -> L_otherwise
 *     <body>
 *     J -> L_end
 *   L_otherwise:
 *     <alt>
 *   L_end:
 *
 * Both branch targets are emitted as 0 and patched once the labels exist.
 * NOTE(review): B presumably branches when cond is non-zero.
 */
static void lower_unless(struct fn *f, struct ast *n)
{
	struct ast *cond = unless_cond(n);
	lower(f, cond);

	const size_t to_otherwise = vec_len(&f->insns);
	output_insn(f, B, null_loc(), cond->l, null_loc(), 0);

	lower_list(f, unless_body(n));

	const size_t to_end = vec_len(&f->insns);
	output_insn(f, J, null_loc(), null_loc(), null_loc(), 0);

	const size_t otherwise_at = vec_len(&f->insns);
	output_label(f);
	patch_branch(f, to_otherwise, otherwise_at);

	lower_list(f, unless_otherwise(n));

	const size_t end_at = vec_len(&f->insns);
	output_label(f);
	patch_branch(f, to_end, end_at);

	n->l = null_loc();
}

/*
 * Lower the expression form `body unless cond otherwise alt`.
 *
 * The result lives in the current slot.  Each arm lowers its expression
 * and MOVEs it into that slot.  The control flow matches lower_unless: B
 * skips the body arm, J skips the otherwise arm, and both targets are
 * patched once the labels exist.
 */
static void lower_unless_expr(struct fn *f, struct ast *n)
{
	n->l = build_local_loc(f->sp);

	struct ast *cond = unless_expr_cond(n);
	lower(f, cond);

	const size_t skip_body = vec_len(&f->insns);
	output_insn(f, B, null_loc(), cond->l, null_loc(), 0);

	struct ast *taken = unless_expr_body(n);
	lower(f, taken);
	output_insn(f, MOVE, n->l, taken->l, null_loc(), 0);

	const size_t skip_otherwise = vec_len(&f->insns);
	output_insn(f, J, null_loc(), null_loc(), null_loc(), 0);

	const size_t otherwise_at = vec_len(&f->insns);
	output_label(f);
	patch_branch(f, skip_body, otherwise_at);

	struct ast *fallback = unless_expr_otherwise(n);
	lower(f, fallback);
	output_insn(f, MOVE, n->l, fallback->l, null_loc(), 0);

	const size_t end_at = vec_len(&f->insns);
	output_label(f);
	patch_branch(f, skip_otherwise, end_at);
}

static void lower_builtin_call(struct fn *f, struct ast *n)
{
	/* for now we only support Today(), which doesn't have any args */
	assert(same_id(builtin_call_id(n), "Today"));

	n->l = build_local_loc(f->sp);
	output_insn(f, TODAY, n->l, null_loc(), null_loc(), 0);
}

static void lower(struct fn *f, struct ast *n)
{
	if (f->max_sp < f->sp)
		f->max_sp = f->sp;

	switch (n->k) {
	case AST_PROC_DEF: break;
	case AST_FUNC_DEF: break;
	case AST_VAR_DEF: lower_var(f, n); break;
	case AST_FORMAL_DEF: lower_formal(f, n); break;
	case AST_CONST_STRING: lower_string(f, n); break;
	case AST_CONST_DATE: lower_date(f, n); break;
	case AST_CONST_INT: lower_int(f, n); break;
	case AST_ASSIGN: lower_assign(f, n); break;
	case AST_ADD: lower_add(f, n); break;
	case AST_SUB: lower_sub(f, n); break;
	case AST_MUL: lower_mul(f, n); break;
	case AST_DIV: lower_div(f, n); break;
	case AST_ID: lower_id(f, n); break;
	case AST_RETURN: lower_return(f, n); break;
	case AST_ATTR: lower_attr(f, n); break;
	case AST_PRINT: lower_print(f, n); break;
	case AST_PROC_CALL: lower_proc_call(f, n); break;
	case AST_FUNC_CALL: lower_func_call(f, n); break;
	case AST_BUILTIN_CALL: lower_builtin_call(f, n); break;
	case AST_EQ: lower_eq(f, n); break;
	case AST_LT: lower_lt(f, n); break;
	case AST_POS: lower_pos(f, n); break;
	case AST_NEG: lower_neg(f, n); break;
	case AST_PRINT_DATE: lower_print_date(f, n); break;
	case AST_PRINT_STRING: lower_print_string(f, n); break;
	case AST_PRINT_BOOL: lower_print_bool(f, n); break;
	case AST_PRINT_INT: lower_print_int(f, n); break;
	case AST_DATE_ADD: lower_date_add(f, n); break;
	case AST_DATE_SUB: lower_date_sub(f, n); break;
	case AST_DATE_DIFF: lower_date_diff(f, n); break;
	case AST_UNTIL: lower_until(f, n); break;
	case AST_UNLESS: lower_unless(f, n); break;
	case AST_UNLESS_EXPR: lower_unless_expr(f, n); break;

	/* handled by assign */
	case AST_DOT: break;
	}

	assert(loc_as_int(n->l) > 0);
}

/*
 * Lower a top-level variable definition.  The variable gets the next free
 * global slot.  Its initializer is lowered into f (the top-level "main"
 * function) and MOVEd into that slot.
 */
static void lower_global_var(struct fn *f, struct ast *n)
{
	n->l = build_global_loc(globals);
	globals += 1;

	struct ast *init = var_expr(n);
	lower(f, init);
	output_insn(f, MOVE, n->l, init->l, null_loc(), 0);
}

static void add_proc(struct ast *n) {
	size_t idx = vec_len(&fns);
	n->l = build_global_loc(idx);
	struct fn f = {.name = proc_id(n),
		       .idx = idx,
		       .sp = 0,
		       .insns = vec_create(sizeof(struct insn))};

	vect_append(struct fn, fns, &f);
}

static void add_func(struct ast *n) {
	size_t idx = vec_len(&fns);
	n->l = build_global_loc(idx);
	struct fn f = {.name = func_id(n),
		       .idx = idx,
		       .sp = 0,
		       .insns = vec_create(sizeof(struct insn))};

	vect_append(struct fn, fns, &f);
}

/*
 * Lower a procedure body into the function registered by add_proc.
 *
 * Formals take the first local slots, then local variables, then the body
 * statements.  Procedures of type TYPE_VOID get a trailing STOP.
 */
static void lower_proc_def(struct ast *d)
{
	/* add_proc stored the function's index in the definition's location */
	struct fn *f = vec_at(&fns, d->l.o);
	assert(f);

	struct ast *formals = proc_formals(d);
	f->params = ast_list_len(formals);

	lower_list(f, formals);
	lower_list(f, proc_vars(d));
	lower_list(f, proc_body(d));

	if (d->t == TYPE_VOID)
		output_insn(f, STOP, null_loc(), null_loc(), null_loc(), 0);
}

/*
 * Lower a function body into the function registered by add_func.
 * Formals take the first local slots, then local variables, then the body
 * statements.  No trailing instruction is added here.
 */
static void lower_func_def(struct ast *d)
{
	/* add_func stored the function's index in the definition's location */
	struct fn *f = vec_at(&fns, d->l.o);
	assert(f);

	struct ast *formals = func_formals(d);
	f->params = ast_list_len(formals);

	lower_list(f, formals);
	lower_list(f, func_vars(d));
	lower_list(f, func_body(d));
}

#ifdef DEBUG
/*
 * Debug print of a location followed by ", ".  The null location prints as
 * "null".  Locals print as "l<offset>" and globals as "g<offset>".
 */
static void dump_loc(struct loc l)
{
	if (is_null_loc(l)) {
		printf("null, ");
		return;
	}

	/* the offset is unsigned (size_t), so use %zu rather than %zd */
	printf("%c%zu, ", l.l ? 'l' : 'g', (size_t)l.o);
}

/* Debug print of an instruction immediate (no trailing separator). */
static void dump_val(int64_t v)
{
	const long long as_ll = (long long)v;
	printf("%lld", as_ll);
}

static void dump_insn(struct insn i, size_t addr)
{
	printf("//%8zd: ", addr);
#define DUMP(x) case x: printf(#x); break;
	switch (i.k) {
	DUMP(TODAY);
	DUMP(CALL);
	DUMP(MOVE);
	DUMP(ADD);
	DUMP(SUB);
	DUMP(MUL);
	DUMP(DIV);
	DUMP(ARG);
	DUMP(STOP);
	DUMP(RETVAL);
	DUMP(PRINT_DATE);
	DUMP(PRINT_INT);
	DUMP(PRINT_STRING);
	DUMP(PRINT_NEWLINE);
	DUMP(PRINT_SPACE);
	DUMP(PRINT_BOOL);
	DUMP(DATE_ADD);
	DUMP(DATE_SUB);
	DUMP(DATE_DIFF);
	DUMP(STORE_DAY);
	DUMP(STORE_MONTH);
	DUMP(STORE_YEAR);
	DUMP(LOAD_DAY);
	DUMP(LOAD_MONTH);
	DUMP(LOAD_YEAR);
	DUMP(LOAD_WEEKDAY);
	DUMP(LOAD_WEEKNUM);
	DUMP(RET);
	DUMP(CONST);
	DUMP(EQ);
	DUMP(LT);
	DUMP(NEG);
	DUMP(B);
	DUMP(BZ);
	DUMP(J);
	DUMP(LABEL);
	}
#undef DUMP

	printf(" ");

	dump_loc(i.o);
	dump_loc(i.i0);
	dump_loc(i.i1);
	dump_val(i.v);

	printf("\n");
}

/* Debug print of every instruction of f, numbered from 0. */
static void dump_insns(struct fn *f)
{
	size_t addr = 0;
	foreach_vec(ii, f->insns) {
		dump_insn(vect_at(struct insn, f->insns, ii), addr++);
	}
}
#endif /* DEBUG */

/* Return the lowered function with global index idx (0 is the top-level "main"). */
struct fn *find_fn(size_t idx)
{
	struct fn *found = vec_at(&fns, idx);
	return found;
}

int lower_ast(struct ast *tree)
{
	fns = vec_create(sizeof(struct fn));

	struct fn main = {.name = "main",
		          .idx = 0,
		          .sp = 0,
		          .insns = vec_create(sizeof(struct insn))};

	vect_append(struct fn, fns, &main);

	foreach_node(n, tree) {
		switch (n->k) {
		case AST_PROC_DEF: add_proc(n); break;
		case AST_FUNC_DEF: add_func(n); break;
		default:
		}
	}

	struct fn *f = &vect_at(struct fn, fns, 0);
	foreach_node(n, tree) {
		switch (n->k) {
		case AST_VAR_DEF: lower_global_var(f, n); break;
		case AST_PROC_DEF: lower_proc_def(n); break;
		case AST_FUNC_DEF: lower_func_def(n); break;
		default: lower(f, n);
		}
	}

	output_insn(f, STOP, null_loc(), null_loc(), null_loc(), 0);

#ifdef DEBUG
	foreach_vec(fi, fns) {
		struct fn *f = vec_at(&fns, fi);
		printf("// %s (%zd):\n", f->name, f->idx);
		dump_insns(f);
	}
#endif
	return 0;
}

/*
 * Release the resources owned by one function: its instruction vector and,
 * if non-NULL, its arena, which is released with munmap (so presumably it
 * was obtained with mmap elsewhere).
 */
static void destroy_fn(struct fn *f)
{
	vec_destroy(&f->insns);

	if (!f->arena)
		return;

	munmap(f->arena, f->size);
}

/*
 * Free every function created by lower_ast, then the function table itself.
 * The empty prototype `()` is spelled `(void)` so the definition declares
 * "no parameters" rather than an unprototyped function.
 */
void destroy_lowering(void)
{
	foreach_vec(fi, fns) {
		struct fn *f = vec_at(&fns, fi);
		destroy_fn(f);
	}

	vec_destroy(&fns);
}