#include <ejit/ejit.h>
#include "../../deps/lightening/lightening/lightening.c"
#include "../common.h"

/*
 * Map an anonymous, private region of `size` bytes that is readable,
 * writable and executable, to be filled with generated machine code.
 *
 * Returns the start of the mapping, or NULL on failure. mmap() reports
 * failure with MAP_FAILED rather than NULL, so the result is translated here
 * to let callers use an ordinary NULL check.
 */
static void *alloc_arena(size_t size)
{
	void *arena = mmap(NULL, size,
	                   PROT_EXEC | PROT_READ | PROT_WRITE,
	                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (arena == MAP_FAILED)
		return NULL;

	return arena;
}

/* Unmap `size` bytes starting at `arena`. The result of munmap() is
 * deliberately ignored. */
static void free_arena(void *arena, size_t size)
{
	(void)munmap(arena, size);
}

/* Number of general purpose locations `f` uses beyond the jit_v_num()
 * available registers; zero when everything fits in registers. */
static size_t grploc_count(struct ejit_func *f)
{
	if (f->gpr < jit_v_num())
		return 0;

	return f->gpr - jit_v_num();
}

/* Number of floating point locations `f` uses beyond the jit_vf_num()
 * available registers; zero when everything fits in registers. */
static size_t frploc_count(struct ejit_func *f)
{
	if (f->fpr < jit_vf_num())
		return 0;

	return f->fpr - jit_vf_num();
}

/* Bytes of stack needed for the locations of `f` that do not fit in
 * registers: one word per general purpose slot followed by one 64bit float
 * per floating point slot. */
static size_t stack_size(struct ejit_func *f)
{
	size_t gpr_bytes = grploc_count(f) * sizeof(jit_uword_t);
	size_t fpr_bytes = frploc_count(f) * sizeof(jit_float64_t);
	return gpr_bytes + fpr_bytes;
}

/* Byte offset of general purpose location `l` within the stack area.
 * Only valid for locations that are not register-resident, i.e.
 * l >= jit_v_num(). */
static jit_off_t stack_loc(size_t l)
{
	assert(l >= jit_v_num());

	size_t slot = l - jit_v_num();
	return slot * sizeof(jit_uword_t);
}

/* Byte offset of floating point location `l` within the stack area of `f`.
 * Floating point slots are placed after all general purpose slots. Only
 * valid for l >= jit_vf_num(). */
static jit_off_t stack_loc_f(struct ejit_func *f, size_t l)
{
	assert(l >= jit_vf_num());

	size_t gpr_area = grploc_count(f) * sizeof(jit_uword_t);
	size_t slot = l - jit_vf_num();
	return gpr_area + slot * sizeof(jit_float64_t);
}


/* A branch or jump whose target address is not known yet. Collected while
 * emitting code and patched once all label addresses have been recorded. */
struct reloc_helper {
	/* relocation handle returned by the emitted branch/jump */
	jit_reloc_t r;
	/* index into the label address vector identifying the target */
	size_t to;
};

/* Make the value of location `l` available in a register and return that
 * register. Register-resident locations are returned directly; stack
 * locations are loaded into scratch register jit_r(i) first. */
static jit_gpr_t getloc(struct ejit_func *f, jit_state_t *j, size_t l, size_t i)
{
	(void)(f);
	if (l >= jit_v_num()) {
		/* spilled, emit a load from its stack slot */
		jit_gpr_t scratch = jit_r(i);
		jit_ldxi(j, scratch, JIT_SP, stack_loc(l));
		return scratch;
	}

	return jit_v(l);
}

/* Pick the register that stands in for location `l` without emitting any
 * code: jit_v(l) when the location lives in a register, otherwise scratch
 * register jit_r(i). Meant for destinations; nothing is loaded. */
static jit_gpr_t getreg(struct ejit_func *f, size_t l, size_t i)
{
	(void)(f);
	return l < jit_v_num() ? jit_v(l) : jit_r(i);
}

/* Write the value in register `r` back to location `l`. Stack locations get
 * a store emitted; register locations need no code, but `r` is expected to
 * already be the register backing `l`. */
static void putloc(struct ejit_func *f, jit_state_t *j, size_t l, jit_gpr_t r)
{
	(void)(f);
	if (l >= jit_v_num()) {
		jit_stxi(j, stack_loc(l), JIT_SP, r);
		return;
	}

	/* register-resident, the value must already be in place */
	assert(jit_v(l).regno == r.regno);
}

/* Record the current emission address as the address of the label at
 * instruction index `ii`. */
static void compile_label(jit_state_t *j, size_t ii, struct vec *labels)
{
	jit_addr_t here = jit_address(j);
	vect_at(jit_addr_t, *labels, ii) = here;
}

/* Emit `r0 = o`: load the immediate i.o into location i.r0. */
static void compile_movi(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	jit_gpr_t dst = getreg(f, i.r0, 0);

	jit_movi(j, dst, i.o);
	putloc(f, j, i.r0, dst);
}

/* Emit `r0 = r1`: copy location i.r1 into location i.r0.
 *
 * The source has to go through getloc() so that a stack-resident location
 * is actually loaded into its scratch register; getreg() only names a
 * register and would leave the scratch register holding stale data. */
static void compile_movr(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	jit_gpr_t to = getreg(f, i.r0, 0);
	jit_gpr_t from = getloc(f, j, i.r1, 1);
	jit_movr(j, to, from);
	putloc(f, j, i.r0, to);
}

/* Emit `r0 = r1 + r2` on locations. */
static void compile_addr(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	/* operands first, they may need loading from the stack */
	jit_gpr_t lhs = getloc(f, j, i.r1, 1);
	jit_gpr_t rhs = getloc(f, j, i.r2, 2);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_addr(j, out, lhs, rhs);
	putloc(f, j, i.r0, out);
}

/* Emit `r0 = r1 + o` where i.o is an immediate. */
static void compile_addi(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 1);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_addi(j, out, lhs, i.o);
	putloc(f, j, i.r0, out);
}

/* Emit `r0 = r1 - r2` on locations. */
static void compile_subr(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	/* operands first, they may need loading from the stack */
	jit_gpr_t lhs = getloc(f, j, i.r1, 1);
	jit_gpr_t rhs = getloc(f, j, i.r2, 2);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_subr(j, out, lhs, rhs);
	putloc(f, j, i.r0, out);
}

/* Emit `r0 = r1 - o` where i.o is an immediate. */
static void compile_subi(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 1);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_subi(j, out, lhs, i.o);
	putloc(f, j, i.r0, out);
}

/* Emit a 64bit store of location i.r0 to the address formed by location
 * i.r1 plus the immediate offset i.o. */
static void compile_stxi64(struct ejit_func *f, jit_state_t *j,
                           struct ejit_insn i)
{
	jit_gpr_t value = getloc(f, j, i.r0, 0);
	jit_gpr_t base = getloc(f, j, i.r1, 1);

	jit_stxi_l(j, i.o, base, value);
}

/* Emit a 64bit load into location i.r0 from the address formed by location
 * i.r1 plus the immediate offset i.o. */
static void compile_ldxiu64(struct ejit_func *f, jit_state_t *j,
                            struct ejit_insn i)
{
	jit_gpr_t base = getloc(f, j, i.r1, 1);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_ldxi_l(j, out, base, i.o);
	putloc(f, j, i.r0, out);
}

/* Emit code that sets location i.r0 to 1 if the comparison implemented by
 * the branch emitter `bcomp` holds between locations i.r1 and i.r2, and to 0
 * otherwise.
 *
 * `same` is the result to store when i.r1 and i.r2 name the same location,
 * in which case no comparison is emitted at all. */
static void compile_reg_cmp(struct ejit_func *f, jit_state_t *j,
                            struct ejit_insn i,
                            jit_reloc_t (*bcomp)(jit_state_t *, jit_gpr_t,
                                                 jit_gpr_t), long same)
{
	jit_gpr_t dst = getreg(f, i.r0, 0);
	if (i.r1 != i.r2) {
		jit_gpr_t lhs = getloc(f, j, i.r1, 1);
		jit_gpr_t rhs = getloc(f, j, i.r2, 2);
		jit_reloc_t taken = bcomp(j, lhs, rhs);

		/* branch not taken, comparison is false */
		jit_movi(j, dst, 0);
		jit_reloc_t skip = jit_jmp(j);

		/* branch taken, comparison is true */
		jit_patch_there(j, taken, jit_address(j));
		jit_movi(j, dst, 1);

		/* both paths meet here */
		jit_patch_there(j, skip, jit_address(j));
	}
	else {
		/* a location compared against itself, result is known */
		jit_movi(j, dst, same);
	}

	/* write final result */
	putloc(f, j, i.r0, dst);
}

/* Emit `r0 = (r1 == r2)`. A location is always equal to itself, so the
 * same-operand result is 1. */
static void compile_eqr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
{
	const long when_same = 1;
	compile_reg_cmp(f, j, i, jit_beqr, when_same);
}

/* Emit `r0 = (r1 < r2)`. A location is never less than itself, so the
 * same-operand result is 0. */
static void compile_ltr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
{
	const long when_same = 0;
	compile_reg_cmp(f, j, i, jit_bltr, when_same);
}

/* Emit a branch taken when location i.r1 < location i.r2 and queue it in
 * `relocs` to be patched to label i.r0 later. */
static void compile_bltr(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i, struct vec *relocs)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 0);
	jit_gpr_t rhs = getloc(f, j, i.r2, 1);

	struct reloc_helper pending = {
		.r = jit_bltr(j, lhs, rhs),
		.to = i.r0
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/* Emit a branch taken when location i.r1 equals the immediate i.o and queue
 * it in `relocs` to be patched to label i.r0 later. */
static void compile_beqi(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i, struct vec *relocs)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 0);

	struct reloc_helper pending = {
		.r = jit_beqi(j, lhs, i.o),
		.to = i.r0
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/* Emit a branch taken when location i.r1 differs from the immediate i.o and
 * queue it in `relocs` to be patched to label i.r0 later. */
static void compile_bnei(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i, struct vec *relocs)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 0);

	struct reloc_helper pending = {
		.r = jit_bnei(j, lhs, i.o),
		.to = i.r0
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/* Emit a branch taken when location i.r1 is greater than the immediate i.o
 * and queue it in `relocs` to be patched to label i.r0 later. */
static void compile_bgti(struct ejit_func *f, jit_state_t *j,
                         struct ejit_insn i, struct vec *relocs)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 0);

	struct reloc_helper pending = {
		.r = jit_bgti(j, lhs, i.o),
		.to = i.r0
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/* Emit an unconditional jump and queue it in `relocs` to be patched to
 * label i.r0 later. */
static void compile_jmp(struct ejit_func *f, jit_state_t *j, struct ejit_insn i,
                        struct vec *relocs)
{
	(void)(f);

	struct reloc_helper pending = {
		.r = jit_jmp(j),
		.to = i.r0
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/* Emit jit_retval() targeting location i.r0, i.e. fetch the return value of
 * the preceding call into that location. */
static void compile_retval(struct ejit_func *f, jit_state_t *j,
                           struct ejit_insn i)
{
	jit_gpr_t dst = getreg(f, i.r0, 0);

	jit_retval(j, dst);
	putloc(f, j, i.r0, dst);
}

/* Translate an ejit type into the matching jit operand ABI. Aborts on a
 * type that has no mapping. */
static enum jit_operand_abi jit_abi_from(enum ejit_type t)
{
	switch (t) {
	case EJIT_INT8: return JIT_OPERAND_ABI_INT8;
	case EJIT_INT16: return JIT_OPERAND_ABI_INT16;
	case EJIT_INT32: return JIT_OPERAND_ABI_INT32;
	case EJIT_INT64: return JIT_OPERAND_ABI_INT64;
	case EJIT_UINT8: return JIT_OPERAND_ABI_UINT8;
	case EJIT_UINT16: return JIT_OPERAND_ABI_UINT16;
	case EJIT_UINT32: return JIT_OPERAND_ABI_UINT32;
	case EJIT_UINT64: return JIT_OPERAND_ABI_UINT64;
	case EJIT_POINTER: return JIT_OPERAND_ABI_POINTER;
	case EJIT_FLOAT: return JIT_OPERAND_ABI_FLOAT;
	case EJIT_DOUBLE: return JIT_OPERAND_ABI_DOUBLE;
	/* a label must be followed by a statement before C23 */
	default: break;
	}

	/* unknown type */
	abort();
}

/* Offset within struct ejit_arg of the member that holds a value of type
 * `t`. Aborts on a type that has no matching member. */
static size_t arg_offsetof(enum ejit_type t)
{
	switch (t) {
	case EJIT_INT8: return offsetof(struct ejit_arg, i8);
	case EJIT_INT16: return offsetof(struct ejit_arg, i16);
	case EJIT_INT32: return offsetof(struct ejit_arg, i32);
	case EJIT_INT64: return offsetof(struct ejit_arg, i64);
	case EJIT_UINT8: return offsetof(struct ejit_arg, u8);
	case EJIT_UINT16: return offsetof(struct ejit_arg, u16);
	case EJIT_UINT32: return offsetof(struct ejit_arg, u32);
	case EJIT_UINT64: return offsetof(struct ejit_arg, u64);
	case EJIT_POINTER: return offsetof(struct ejit_arg, p);
	case EJIT_FLOAT: return offsetof(struct ejit_arg, f);
	case EJIT_DOUBLE: return offsetof(struct ejit_arg, d);
	/* a label must be followed by a statement before C23 */
	default: break;
	}

	/* unknown type */
	abort();
}

/* Byte offset of the value member for argument instruction `i` within an
 * array of struct ejit_arg: i.r0 selects the array element and i.r1 (the
 * type) selects the member inside it. */
static jit_off_t arg_offset(struct ejit_insn i)
{
	size_t element = sizeof(struct ejit_arg) * i.r0;
	return element + arg_offsetof(i.r1);
}

/* Byte offset of the `type` member for argument instruction `i` within an
 * array of struct ejit_arg, i.r0 being the array index. */
static jit_off_t type_offset(struct ejit_insn i)
{
	size_t element = sizeof(struct ejit_arg) * i.r0;
	return element + offsetof(struct ejit_arg, type);
}

static void fixup_operands(struct vec *operands, size_t fixup)
{
	foreach_vec(i, *operands) {
		jit_operand_t op = vect_at(jit_operand_t, *operands, i);
		if (op.kind != JIT_OPERAND_KIND_MEM)
			continue;

		op.loc.mem.offset += fixup;
		vect_at(jit_operand_t, *operands, i) = op;
	}
}

/* Emit a call to the fixed address `addr` with the `argc` operands in
 * `args`, after laying out the queued arguments on the stack.
 *
 * `src` and `dst` hold the queued moves, two operands (type and value) per
 * argument. Stack space for one struct ejit_arg per argument is reserved,
 * the moves are performed, the call is emitted and the stack is restored.
 * Both vectors are reset afterwards so they can be reused. */
static void compile_imm_call(jit_state_t *j, struct vec *src, struct vec *dst,
                             void *addr, size_t argc, jit_operand_t args[argc])
{
	/* type + value per argument, so half as many arguments as operands */
	size_t queued = vec_len(src) / 2;
	size_t grown = jit_align_stack(j, queued * sizeof(struct ejit_arg));

	/* SP moved, so SP-relative sources shift by the same amount;
	 * destinations are relative to the new SP and stay as they are */
	fixup_operands(src, grown);

	/* every operand has to be moved, not just one per argument */
	jit_move_operands(j, dst->buf, src->buf, queued * 2);
	jit_calli(j, addr, argc, args);
	jit_shrink_stack(j, grown);

	vec_reset(dst);
	vec_reset(src);
}

static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
                              size_t size)
{
	jit_begin(j, arena, size);
	size_t gprs = f->gpr >= jit_v_num() ? jit_v_num() : f->gpr;
	size_t fprs = f->fpr >= jit_vf_num() ? jit_vf_num() : f->fpr;
	size_t frame = jit_enter_jit_abi(j, gprs, fprs, 0);

	/* very important, argc we don't really do anything with but JIR_R1
	 * contains the argument stack! */
	jit_load_args_2(j,
	                jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R0),
	                jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_R1));

	size_t stack = jit_align_stack(j, stack_size(f));

	struct vec src = vec_create(sizeof(jit_operand_t));
	struct vec dst = vec_create(sizeof(jit_operand_t));

	struct vec relocs = vec_create(sizeof(struct reloc_helper));
	struct vec labels = vec_create(sizeof(jit_addr_t));
	vec_reserve(&labels, vec_len(&f->insns));

	size_t label = 0;
	foreach_vec(ii, f->insns) {
		/* if we've hit a label, add it to our vector of label addresses */
		if (label < vec_len(&f->labels)) {
			if (vect_at(size_t, f->labels, label) == ii) {
				compile_label(j, ii, &labels);
				label++;
			}
		}

		struct ejit_insn i = vect_at(struct ejit_insn, f->insns, ii);
		switch (i.op) {
		case MOVR: compile_movr(f, j, i); break;
		case MOVI: compile_movi(f, j, i); break;
		case ADDR: compile_addr(f, j, i); break;
		case ADDI: compile_addi(f, j, i); break;
		case SUBR: compile_subr(f, j, i); break;
		case SUBI: compile_subi(f, j, i); break;

		case STXI64: compile_stxi64(f, j, i); break;
		case LDXIU64: compile_ldxiu64(f, j, i); break;

		case EQR: compile_eqr(f, j, i); break;
		case LTR: compile_ltr(f, j, i); break;

		case BLTR: compile_bltr(f, j, i, &relocs); break;
		case BEQI: compile_beqi(f, j, i, &relocs); break;
		case BNEI: compile_bnei(f, j, i, &relocs); break;
		case BGTI: compile_bgti(f, j, i, &relocs); break;
		case JMP: compile_jmp(f, j, i, &relocs); break;

		case ARG: {
			jit_operand_t type =
				jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1);
			jit_operand_t arg;
			if (i.r0 < jit_v_num()) {
				/* regular register */
				arg = jit_operand_gpr(jit_abi_from(i.r1),
				                      jit_v(i.r2));
			}
			else {
				/* stack location, note that we'll fix up the SP
				 * offset before doing the actual call */
				arg = jit_operand_mem(jit_abi_from(i.r1),
				                      JIT_SP, stack_loc(i.r0));
			}

			vec_append(&src, &type);
			vec_append(&src, &arg);

			jit_operand_t to[2] = {
				jit_operand_mem(JIT_OPERAND_ABI_WORD, JIT_SP,
				                type_offset(i)),
				jit_operand_mem(jit_abi_from(i.r1), JIT_SP,
				                arg_offset(i))
			};

			vec_append(&dst, &to[0]);
			vec_append(&dst, &to[1]);
			break;
		}

		case ESCAPEI: {
			jit_operand_t args[2] = {
				jit_operand_imm(JIT_OPERAND_ABI_WORD,
				                vec_len(&src) / 2),
				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP)
			};
			compile_imm_call(j, &src, &dst, (void *)i.o, 2, args);
			break;
		}

		case CALLI: {
			jit_operand_t args[3] = {
				jit_operand_imm(JIT_OPERAND_ABI_POINTER, i.o),
				jit_operand_imm(JIT_OPERAND_ABI_WORD,
				                vec_len(&src) / 2),
				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP)
			};
			compile_imm_call(j, &src, &dst, ejit_run_func, 3, args);
			break;
		}

		case RETVAL: compile_retval(f, j, i); break;
		case RETR: {
			jit_gpr_t r = getloc(f, j, i.r0, 0);
			/* R0 won't get overwritten by jit_leave_jit_abi */
			jit_movr(j, JIT_R0, r);
			jit_shrink_stack(j, stack);
			jit_leave_jit_abi(j, gprs, fprs, frame);
			jit_retr(j, JIT_R0);
			break;
		}

		case RETI: {
			jit_shrink_stack(j, stack);
			jit_leave_jit_abi(j, gprs, fprs, frame);
			jit_reti(j, i.o);
			break;
		}

		case END: {
			/* 'void' return */
			jit_shrink_stack(j, stack);
			jit_leave_jit_abi(j, gprs, fprs, frame);
			jit_reti(j, 0);
			break;
		}

		case PARAM: {
			/* move from argument stack to location */
			jit_operand_t from = jit_operand_mem(
				jit_abi_from(i.r1),
				JIT_R1,
				arg_offset(i)
				);

			jit_operand_t to;
			if (i.r0 < jit_v_num()) {
				/* regular register */
				to = jit_operand_gpr(jit_abi_from(i.r1),
				                     jit_v(i.r2));
			}
			else {
				/* stack location */
				to = jit_operand_mem(jit_abi_from(i.r1), JIT_SP,
				                     stack_loc(i.r2));
			}

			vec_append(&src, &from);
			vec_append(&dst, &to);
			break;
		}

		case START: {
			/* parameters should be done by now */
			jit_move_operands(j, dst.buf, src.buf, vec_len(&src));
			/* reuse for arguments */
			vec_reset(&dst);
			vec_reset(&src);
			break;
		}

		default: abort();
		}
	}

	foreach_vec(ri, relocs) {
		struct reloc_helper h = vect_at(struct reloc_helper, relocs,
		                                ri);
		jit_addr_t a = vect_at(jit_addr_t, labels, h.to);
		jit_reloc_t r = h.r;

		assert(a);
		jit_patch_there(j, r, a);
	}

	vec_destroy(&src);
	vec_destroy(&dst);
	vec_destroy(&relocs);
	vec_destroy(&labels);

	if (jit_end(j, &size))
		return 0;

	return size;
}

/*
 * Compile `f` to machine code.
 *
 * Starts with a small arena and retries with a larger one for as long as
 * compile_fn_body() reports that the code did not fit.
 *
 * Returns true on success, in which case f->arena points at the mapping
 * holding the code and f->size is the size of that mapping. Returns false
 * if init_jit() fails, if the JIT state cannot be created or if an arena
 * cannot be mapped; `f` is left untouched in that case.
 */
bool ejit_compile(struct ejit_func *f)
{
	/* initial arena size and extra room added on every retry */
	enum { ARENA_STEP = 4096 };

	if (!init_jit())
		return false;

	jit_state_t *j = jit_new_state(NULL, NULL);
	if (!j)
		return false;

	void *arena = NULL;
	size_t size = ARENA_STEP;

	while (1) {
		arena = alloc_arena(size);
		/* mmap() signals failure with MAP_FAILED, accept NULL as
		 * well so a real check survives builds with NDEBUG */
		if (arena == NULL || arena == MAP_FAILED) {
			jit_destroy_state(j);
			return false;
		}

		size_t required_size = compile_fn_body(f, j, arena, size);
		if (required_size == 0)
			break;

		/* did not fit, throw the arena away and try a bigger one */
		free_arena(arena, size);
		size = required_size + ARENA_STEP;
	}

	jit_destroy_state(j);
	f->arena = arena;
	f->size = size;
	return true;
}