#include <stdlib.h>

#include <ejit/ejit.h>

#include "common.h"

/* Bytecode interpreter for ejit functions.
 *
 * When run is false nothing is executed. The static opcode -> handler-label
 * table is written to *labels_wb and a zero result is returned. Dispatch
 * below jumps through insns[pc].addr directly. Presumably a caller uses this
 * table to fill in each instruction's .addr before running
 * (NOTE(review): confirm against the code that builds f->insns).
 *
 * When run is true, f->insns is executed starting at index 0 until a RET*
 * or END instruction is reached:
 *   - PARAM/PARAM_F read the incoming args[] (argc entries). The indices are
 *     not bounds-checked here and are assumed valid.
 *   - ARG* instructions accumulate outgoing call arguments, which are
 *     consumed and cleared by CALLI/CALLI_F/ESCAPEI.
 *   - RET, RET_I and END return through .r.
 *   - RET_F and RET_FI return through .d.
 *
 * END returns whatever the integer return-value slot currently holds. That
 * is 0 unless an earlier CALLI/ESCAPEI stored a result there.
 *
 * Running out of memory for the register files is fatal (abort()), since
 * union interp_ret has no way to report an error.
 */
union interp_ret ejit_interp(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], bool run, void ***labels_wb)
{
	static void *labels[OPCODE_COUNT] = {
		[MOVI] = &&MOVI,
		[MOVR] = &&MOVR,
		[MOVR_F] = &&MOVR_F,

		[ADDR] = &&ADDR,
		[ADDR_F] = &&ADDR_F,
		[ADDI] = &&ADDI,

		[SUBR] = &&SUBR,
		[SUBR_F] = &&SUBR_F,
		[SUBI] = &&SUBI,

		[MULR] = &&MULR,
		[DIVR] = &&DIVR,

		[EQR] = &&EQR,
		[LTR] = &&LTR,

		[STXI64] = &&STXI64,
		[LDXIU64] = &&LDXIU64,

		[BLTR] = &&BLTR,
		[BNEI] = &&BNEI,
		[BEQI] = &&BEQI,
		[BGTI] = &&BGTI,

		[JMP] = &&JMP,

		[RET] = &&RET,
		[RET_I] = &&RET_I,
		[RET_F] = &&RET_F,
		[RET_FI] = &&RET_FI,

		[RETVAL] = &&RETVAL,

		[ARG] = &&ARG,
		[ARG_I] = &&ARG_I,
		[ARG_F] = &&ARG_F,
		[ARG_FI] = &&ARG_FI,

		[PARAM] = &&PARAM,
		[PARAM_F] = &&PARAM_F,

		[CALLI] = &&CALLI,
		[CALLI_F] = &&CALLI_F,
		[ESCAPEI] = &&ESCAPEI,

		[START] = &&START,
		[END] = &&END,
	};

	if (!run) {
		*labels_wb = labels;
		return (union interp_ret){.r = 0};
	}

	/* calloc rather than malloc has two benefits. A register read before it
	 * is written yields 0 rather than an indeterminate value. The
	 * count * size product is also overflow-checked. */
	long *gpr = calloc(f->gpr, sizeof(long));
	double *fpr = calloc(f->fpr, sizeof(double));

	/* calloc(0, ...) may legitimately return NULL. Only a NULL for a
	 * non-empty register file is an allocation failure. */
	if ((f->gpr && !gpr) || (f->fpr && !fpr))
		abort();

	struct vec call_args = vec_create(sizeof(struct ejit_arg));
	struct ejit_insn *insns = f->insns.buf;

	/* retval is kind of an unfortunate extra bit of state to keep track of,
	 * but having call and return value separated is pretty convenient for
	 * void calls so I guess I don't mind? */
	long retval = 0; double retval_f = 0.;
	size_t pc = 0;
	union interp_ret ret = {.r = 0};

	/* DO(x) opens the handler for opcode x with a private copy of the
	 * current instruction. The braces let several handlers declare locals
	 * with the same name (addr, a, ...).
	 * JUMP(a) sets pc to a and transfers to that instruction's handler.
	 * DISPATCH() closes the handler and falls through to the next
	 * instruction. */
#define DO(x) x: { struct ejit_insn i = insns[pc]; (void)i;
#define JUMP(a) goto *insns[pc = (a)].addr;
#define DISPATCH() } goto *insns[++pc].addr;

	JUMP(0);

	DO(START);
	DISPATCH();

	DO(END);
	goto out_int;
	DISPATCH();

	DO(MOVI);
	gpr[i.r0] = i.o;
	DISPATCH();

	DO(MOVR);
	gpr[i.r0] = gpr[i.r1];
	DISPATCH();

	DO(MOVR_F);
	fpr[i.r0] = fpr[i.r1];
	DISPATCH();

	DO(ADDR);
	gpr[i.r0] = gpr[i.r1] + gpr[i.r2];
	DISPATCH();

	DO(ADDR_F);
	fpr[i.r0] = fpr[i.r1] + fpr[i.r2];
	DISPATCH();

	DO(ADDI);
	gpr[i.r0] = gpr[i.r1] + i.o;
	DISPATCH();

	DO(SUBR);
	gpr[i.r0] = gpr[i.r1] - gpr[i.r2];
	DISPATCH();

	DO(SUBR_F);
	fpr[i.r0] = fpr[i.r1] - fpr[i.r2];
	DISPATCH();

	DO(SUBI);
	gpr[i.r0] = gpr[i.r1] - i.o;
	DISPATCH();

	DO(MULR);
	gpr[i.r0] = gpr[i.r1] * gpr[i.r2];
	DISPATCH();

	/* NOTE(review): division by zero (or LONG_MIN / -1) is undefined
	 * behavior here. It is up to the bytecode producer to avoid it. */
	DO(DIVR);
	gpr[i.r0] = gpr[i.r1] / gpr[i.r2];
	DISPATCH();

	DO(EQR);
	gpr[i.r0] = gpr[i.r1] == gpr[i.r2];
	DISPATCH();

	DO(LTR);
	gpr[i.r0] = gpr[i.r1] < gpr[i.r2];
	DISPATCH();

	/* Store gpr[r0] as 64 bits at address gpr[r1] + o. */
	DO(STXI64);
	int64_t *addr = (int64_t *)(gpr[i.r1] + i.o);
	*addr = gpr[i.r0];
	DISPATCH();

	/* Load an unsigned 64-bit value from address gpr[r1] + o into gpr[r0]. */
	DO(LDXIU64);
	uint64_t *addr = (uint64_t *)(gpr[i.r1] + i.o);
	gpr[i.r0] = *addr;
	DISPATCH();

	/* Branches: r0 holds the target instruction index. */
	DO(BLTR);
	if (gpr[i.r1] < gpr[i.r2])
		JUMP(i.r0);

	DISPATCH();

	DO(BNEI);
	if (gpr[i.r1] != i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(BEQI);
	if (gpr[i.r1] == i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(BGTI);
	if (gpr[i.r1] > i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(JMP);
	JUMP(i.r0);
	DISPATCH();

	/* Copy the integer result of the most recent call into gpr[r0]. */
	DO(RETVAL);
	gpr[i.r0] = retval;
	DISPATCH();

	/* Parameter index is r0, destination register is r2. */
	DO(PARAM);
	gpr[i.r2] = args[i.r0].l;
	DISPATCH();

	DO(PARAM_F);
	fpr[i.r2] = args[i.r0].d;
	DISPATCH();

	/* Outgoing arguments: r1 is passed to ejit_build_arg* (presumably the
	 * argument type), and the value comes from r2 or the immediate. */
	DO(ARG);
	struct ejit_arg a = ejit_build_arg(i.r1, gpr[i.r2]);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(ARG_I);
	struct ejit_arg a = ejit_build_arg(i.r1, i.o);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(ARG_F);
	struct ejit_arg a = ejit_build_arg_f(i.r1, fpr[i.r2]);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(ARG_FI);
	struct ejit_arg a = ejit_build_arg_f(i.r1, i.d);
	vec_append(&call_args, &a);
	DISPATCH();

	/* Calls consume the accumulated arguments and then clear them for the
	 * next call. */
	DO(CALLI);
	struct ejit_func *callee = i.p;
	retval = ejit_run_func(callee, vec_len(&call_args), call_args.buf);
	vec_reset(&call_args);
	DISPATCH();

	DO(CALLI_F);
	struct ejit_func *callee = i.p;
	retval_f = ejit_run_func_f(callee, vec_len(&call_args), call_args.buf);
	vec_reset(&call_args);
	DISPATCH();

	DO(ESCAPEI);
	ejit_escape_t escape = i.p;
	retval = escape(vec_len(&call_args), call_args.buf);
	vec_reset(&call_args);
	DISPATCH();

	/* dispatch is technically unnecessary for returns, but keep it for
	 * symmetry */
	DO(RET);
	retval = gpr[i.r0];
	goto out_int;
	DISPATCH();

	DO(RET_I);
	retval = i.o;
	goto out_int;
	DISPATCH();

	DO(RET_F);
	retval_f = fpr[i.r0];
	goto out_float;
	DISPATCH();

	DO(RET_FI);
	retval_f = i.d;
	goto out_float;
	DISPATCH();

#undef DISPATCH
#undef JUMP
#undef DO

out_int:
	ret = (union interp_ret){.r = retval};
	goto out;

out_float:
	ret = (union interp_ret){.d = retval_f};

out:
	/* Single cleanup path shared by integer and float returns. */
	free(gpr);
	free(fpr);
	vec_destroy(&call_args);
	return ret;
}