#include <ejit/ejit.h>
#include "../../deps/lightening/lightening/lightening.c"
#include "../common.h"

/*
 * Map an anonymous, private region of `size` bytes that is readable,
 * writable and executable, used to hold generated machine code.
 *
 * Returns NULL on failure. mmap() itself signals failure with MAP_FAILED
 * ((void *)-1), not NULL, so that value is translated here; otherwise a
 * caller's NULL check would silently accept a failed mapping.
 */
static void *alloc_arena(size_t size)
{
	void *arena = mmap(NULL, size,
			PROT_EXEC | PROT_READ | PROT_WRITE,
			MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (arena == MAP_FAILED)
		return NULL;

	return arena;
}

/*
 * Unmap a region obtained from alloc_arena(). `size` must be the same length
 * that was used for the mapping. The munmap() result is intentionally
 * discarded.
 */
static void free_arena(void *arena, size_t size)
{
	(void)munmap(arena, size);
}

/*
 * Number of general-purpose virtual registers that do not fit in the
 * jit_v(0)..jit_v(jit_v_num() - 1) registers and therefore need a stack
 * slot. Zero when every register fits.
 */
static size_t grploc_count(struct ejit_func *f)
{
	/* Test before subtracting: f->gpr - jit_v_num() would wrap around
	 * (size_t) when there are fewer virtual registers than JIT_V regs. */
	return f->gpr <= jit_v_num() ? 0 : f->gpr - jit_v_num();
}

/*
 * Number of floating-point virtual registers that do not fit in the
 * jit_vf_num() JIT_VF registers and therefore need a stack slot. Zero when
 * every register fits.
 */
static size_t frploc_count(struct ejit_func *f)
{
	/* Test before subtracting: f->fpr - jit_vf_num() would wrap around
	 * (size_t) when there are fewer virtual registers than JIT_VF regs. */
	return f->fpr <= jit_vf_num() ? 0 : f->fpr - jit_vf_num();
}

/*
 * Bytes of spill space required by `f`: one jit_uword_t per spilled GPR,
 * followed by one jit_float64_t per spilled FPR. This is the layout that
 * stack_loc() and stack_loc_f() index into.
 */
static size_t stack_size(struct ejit_func *f)
{
	size_t gpr_bytes = sizeof(jit_uword_t) * grploc_count(f);
	size_t fpr_bytes = sizeof(jit_float64_t) * frploc_count(f);

	return gpr_bytes + fpr_bytes;
}

/*
 * Byte offset from JIT_SP of the spill slot for virtual GPR `l`.
 * Only registers with l >= jit_v_num() have a slot. Slots are
 * jit_uword_t-sized and packed from offset 0. `f` is currently unused.
 */
static jit_off_t stack_loc(struct ejit_func *f, size_t l)
{
	assert(l >= jit_v_num());

	size_t slot = l - jit_v_num();
	return slot * sizeof(jit_uword_t);
}

/*
 * Byte offset from JIT_SP of the spill slot for virtual FPR `l`
 * (requires l >= jit_vf_num()). FPR slots are jit_float64_t-sized and
 * start immediately after the whole GPR spill area.
 */
static jit_off_t stack_loc_f(struct ejit_func *f, size_t l)
{
	assert(l >= jit_vf_num());

	size_t gpr_area = grploc_count(f) * sizeof(jit_uword_t);
	size_t fpr_slot = l - jit_vf_num();
	return gpr_area + fpr_slot * sizeof(jit_float64_t);
}

/*
 * A branch emitted before its target address was known. compile_fn_body()
 * collects these and patches each one once all LABEL addresses are recorded.
 */
struct reloc_helper {
	jit_reloc_t r;	/* unresolved branch, e.g. as returned by jit_bltr() */
	size_t to;	/* instruction index of the target LABEL */
};

/*
 * Return a machine register that holds the current value of virtual GPR `l`.
 * Registers below jit_v_num() map directly onto jit_v(l). Any other register
 * is loaded from its stack slot into scratch register jit_r(i), so operands
 * that must be live together need distinct `i` values.
 */
static jit_gpr_t getloc(struct ejit_func *f, jit_state_t *j, size_t l, size_t i)
{
	if (l >= jit_v_num()) {
		jit_gpr_t tmp = jit_r(i);
		jit_ldxi(j, tmp, JIT_SP, stack_loc(f, l));
		return tmp;
	}

	return jit_v(l);
}

/*
 * Pick the register an instruction should write virtual GPR `l` into,
 * without emitting any load. This is jit_v(l) when `l` is register-resident,
 * otherwise scratch register jit_r(i). Follow up with putloc() to store the
 * result. `f` is currently unused.
 */
static jit_gpr_t getreg(struct ejit_func *f, size_t l, size_t i)
{
	return l < jit_v_num() ? jit_v(l) : jit_r(i);
}

/*
 * Write machine register `r`, holding the new value of virtual GPR `l`,
 * back to where `l` lives. Spilled registers are stored to their stack slot.
 * Register-resident ones already hold the value, so nothing is emitted for
 * them.
 */
static void putloc(struct ejit_func *f, jit_state_t *j, size_t l, jit_gpr_t r)
{
	if (l >= jit_v_num()) {
		jit_stxi(j, stack_loc(f, l), JIT_SP, r);
		return;
	}

	/* getreg() hands out exactly jit_v(l) for register-resident values. */
	assert(jit_v(l).regno == r.regno);
}

/*
 * Record the current emission address as the location of the LABEL at
 * instruction index `ii`, for later use when patching branches.
 */
static void compile_label(jit_state_t *j, size_t ii, struct vec *labels)
{
	jit_addr_t here = jit_address(j);
	vect_at(jit_addr_t, *labels, ii) = here;
}

/* Emit MOVI: load the immediate i.o into virtual GPR i.r0. */
static void compile_movi(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
{
	jit_gpr_t dst = getreg(f, i.r0, 0);

	jit_movi(j, dst, i.o);
	putloc(f, j, i.r0, dst);
}

/*
 * Emit ADDR: i.r0 = i.r1 + i.r2. Scratch slots 1 and 2 keep the two source
 * operands apart, and slot 0 is used for the destination.
 */
static void compile_addr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
{
	jit_gpr_t lhs = getloc(f, j, i.r1, 1);
	jit_gpr_t rhs = getloc(f, j, i.r2, 2);
	jit_gpr_t out = getreg(f, i.r0, 0);

	jit_addr(j, out, lhs, rhs);
	putloc(f, j, i.r0, out);
}

/*
 * Emit BLTR: branch, per jit_bltr(), when i.r0 < i.r1, to the LABEL at
 * instruction index i.o. The target address is not known yet, so the reloc
 * is queued in `relocs` and patched after the whole body is emitted.
 */
static void compile_bltr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i, struct vec *relocs)
{
	jit_gpr_t a = getloc(f, j, i.r0, 0);
	jit_gpr_t b = getloc(f, j, i.r1, 1);

	struct reloc_helper pending = {
		.r = jit_bltr(j, a, b),
		.to = i.o,
	};
	vect_append(struct reloc_helper, *relocs, &pending);
}

/*
 * Emit machine code for `f` into the `size`-byte buffer `arena`.
 *
 * Virtual registers below jit_v_num()/jit_vf_num() live in JIT_V/JIT_VF
 * registers. The rest are spilled to a frame area of stack_size(f) bytes.
 * Branches are emitted with unresolved relocs and patched once every LABEL
 * address is known.
 *
 * Returns 0 on success. If the code did not fit in `arena`, returns a
 * non-zero size hint (currently the arena size that proved too small) so the
 * caller can retry with a bigger arena. In both cases `j` is left ready for
 * another jit_begin().
 */
static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, size_t size)
{
	jit_begin(j, arena, size);
	size_t gprs = f->gpr >= jit_v_num() ? jit_v_num() : f->gpr;
	size_t fprs = f->fpr >= jit_vf_num() ? jit_vf_num() : f->fpr;
	size_t frame = jit_enter_jit_abi(j, gprs, fprs, 0);

	size_t stack = jit_align_stack(j, stack_size(f));

	struct vec relocs = vec_create(sizeof(struct reloc_helper));
	struct vec labels = vec_create(sizeof(jit_addr_t));
	/* NOTE(review): compile_label() writes labels[ii] and the patch loop
	 * asserts non-NULL entries, so this relies on vec_reserve() making
	 * those indices addressable and zero-filled. Confirm in common.h. */
	vec_reserve(&labels, vec_len(&f->insns));

	foreach_vec(ii, f->insns) {
		struct ejit_insn i = vect_at(struct ejit_insn, f->insns, ii);
		switch (i.op) {
		case MOVI: compile_movi(f, j, i); break;
		case ADDR: compile_addr(f, j, i); break;

		case BLTR: compile_bltr(f, j, i, &relocs); break;

		case LABEL: compile_label(j, ii, &labels); break;
		case RET: {
			jit_gpr_t r = getloc(f, j, i.r0, 0);
			/* R0 won't get overwritten by jit_leave_jit_abi */
			jit_movr(j, JIT_R0, r);
			jit_shrink_stack(j, stack);
			jit_leave_jit_abi(j, gprs, fprs, frame);
			jit_retr(j, JIT_R0);
			break;
		}

		case START: continue;
		case END: continue;
		default: abort();
		}
	}

	/* Resolve every forward/backward branch now that all labels exist. */
	foreach_vec(ri, relocs) {
		struct reloc_helper h = vect_at(struct reloc_helper, relocs, ri);
		jit_addr_t a = vect_at(jit_addr_t, labels, h.to);
		jit_reloc_t r = h.r;

		assert(a);
		jit_patch_there(j, r, a);
	}

	vec_destroy(&relocs);
	vec_destroy(&labels);

	if (jit_end(j, &size))
		return 0;

	/*
	 * Overflow: jit_end() returns NULL without clearing the in-progress
	 * emission, and jit_begin() asserts that none is in progress. Reset
	 * so the caller's retry with a larger arena can start cleanly.
	 */
	jit_reset(j);
	return size;
}

/*
 * JIT-compile `f` into a freshly mapped executable arena.
 *
 * Starts with a 4096-byte arena and retries with a larger one whenever
 * compile_fn_body() reports that the code did not fit. On success, stores the
 * arena and its mapped size in f->arena / f->size and returns true.
 * Returns false, leaving `f` untouched, if lightening cannot be initialised
 * or if the jit state or an arena cannot be allocated.
 */
bool ejit_compile(struct ejit_func *f)
{
	if (!init_jit())
		return false;

	jit_state_t *j = jit_new_state(NULL, NULL);
	if (!j)
		return false;

	void *arena = NULL;
	size_t size = 4096;

	while (1) {
		arena = alloc_arena(size);
		/*
		 * A runtime failure, not a programming error, so check it
		 * explicitly instead of with assert() (which NDEBUG removes).
		 * mmap() reports failure as MAP_FAILED, so reject that as well
		 * as NULL.
		 */
		if (arena == NULL || arena == MAP_FAILED) {
			jit_destroy_state(j);
			return false;
		}

		size_t required_size = compile_fn_body(f, j, arena, size);
		if (required_size == 0)
			break;

		free_arena(arena, size);
		size = required_size + 4096;
	}

	jit_destroy_state(j);
	f->arena = arena;
	f->size = size;
	return true;
}