#include <ejit/ejit.h>
#include <math.h>
#include <stddef.h>

#include "common.h"

/* Two's-complement wrapping arithmetic for register values.
 *
 * Overflowing signed `long` arithmetic is undefined behaviour in C, so a
 * guest program that overflows (which native code would simply wrap) could
 * let the compiler miscompile the interpreter. Doing the operation in
 * `unsigned long` is always defined; converting the result back to `long`
 * is implementation-defined, but GCC and Clang (which this file already
 * requires for computed gotos) define it as modular wrap-around. */
static inline long interp_wrap_add(long a, long b)
{
	return (long)((unsigned long)a + (unsigned long)b);
}

static inline long interp_wrap_sub(long a, long b)
{
	return (long)((unsigned long)a - (unsigned long)b);
}

static inline long interp_wrap_mul(long a, long b)
{
	return (long)((unsigned long)a * (unsigned long)b);
}

/* Computed-goto interpreter for a single ejit function.
 *
 * run == false: nothing is executed; the opcode -> handler label table is
 * written to *labels_wb and {.r = 0} is returned. The caller presumably
 * stores those labels in each instruction's `addr` field, since dispatch
 * below jumps through insns[pc].addr.
 *
 * run == true: executes f from instruction 0 with the given arguments.
 * `state` holds stacks of integer registers, float registers and pending
 * call arguments that are shared with nested interpreted calls; this frame
 * is pushed on entry and popped on every exit path. Returns the integer
 * result in .r (RETR, RETI, END) or the floating-point result in .d
 * (RETR_F, RETI_F). */
union interp_ret ejit_interp(struct ejit_func *f, size_t argc,
                             struct ejit_arg args[argc],
                             struct interp_state *state, bool run,
                             void ***labels_wb)
{
	static void *labels[OPCODE_COUNT] = {
		[MOVI] = &&MOVI,
		[MOVR] = &&MOVR,
		[MOVR_F] = &&MOVR_F,

		[ADDR] = &&ADDR,
		[ADDR_F] = &&ADDR_F,
		[ADDI] = &&ADDI,

		[ABSR_F] = &&ABSR_F,

		[SUBR] = &&SUBR,
		[SUBR_F] = &&SUBR_F,
		[SUBI] = &&SUBI,

		[MULR] = &&MULR,
		[DIVR] = &&DIVR,

		[ANDR] = &&ANDR,
		[ANDI] = &&ANDI,

		[EQR] = &&EQR,
		[LTR] = &&LTR,

		[STXI64] = &&STXI64,
		[LDXIU64] = &&LDXIU64,

		[BLTR] = &&BLTR,
		[BNEI] = &&BNEI,
		[BEQR] = &&BEQR,
		[BEQI] = &&BEQI,
		[BGTI] = &&BGTI,

		[JMP] = &&JMP,

		[RETR] = &&RETR,
		[RETI] = &&RETI,
		[RETR_F] = &&RETR_F,
		[RETI_F] = &&RETI_F,

		[RETVAL] = &&RETVAL,

		[ARG] = &&ARG,
		[ARG_I] = &&ARG_I,
		[ARG_F] = &&ARG_F,
		[ARG_FI] = &&ARG_FI,

		[PARAM] = &&PARAM,
		[PARAM_F] = &&PARAM_F,

		[CALLI] = &&CALLI,
		[CALLI_F] = &&CALLI_F,
		[ESCAPEI] = &&ESCAPEI,

		[START] = &&START,
		[END] = &&END,
	};

	if (!run) {
		*labels_wb = labels;
		return (union interp_ret){.r = 0};
	}

	size_t prev_gprs = vec_len(&state->gprs);
	size_t prev_fprs = vec_len(&state->fprs);
	size_t prev_argc = vec_len(&state->args);

	/* Push this call's register frame on top of the enclosing calls'.
	 * NOTE(review): this assumes vec_reserve also grows the vectors'
	 * length (the exit paths vec_shrink back to prev_*); if it only
	 * reserved capacity, a nested call would start its frame at the same
	 * offset and clobber ours — confirm against the vec implementation. */
	vec_reserve(&state->gprs, prev_gprs + f->gpr);
	vec_reserve(&state->fprs, prev_fprs + f->fpr);

	/* These point into vector storage, so they are reloaded after every
	 * nested interpreted call, which may reallocate it. */
	long *gpr = ((long *)state->gprs.buf) + prev_gprs;
	double *fpr = ((double *)state->fprs.buf) + prev_fprs;

	struct ejit_insn *insns = f->insns.buf;

	/* Result of the most recent call (read back with RETVAL) or the value
	 * being returned. Keeping it separate from the call instruction means
	 * calls whose result is unused need no destination register. */
	int64_t retval = 0; double retval_f = 0.;
	size_t pc = 0;

	/* DO(x) defines handler label x and opens a scope holding a copy of
	 * the current instruction; DISPATCH() closes that scope and jumps to
	 * the next instruction's handler. Every DO() therefore needs a
	 * matching DISPATCH(), even after an unconditional goto. */
#define DO(x) x : { struct ejit_insn i = insns[pc]; (void)i;
#define JUMP(a) goto *insns[pc = a].addr;
#define DISPATCH() } goto *insns[++pc].addr;

	JUMP(0);

	DO(START);
	DISPATCH();

	DO(END);
	goto out_int;
	DISPATCH();

	DO(MOVI);
	gpr[i.r0] = i.o;
	DISPATCH();

	DO(MOVR);
	gpr[i.r0] = gpr[i.r1];
	DISPATCH();

	DO(MOVR_F);
	fpr[i.r0] = fpr[i.r1];
	DISPATCH();

	DO(ADDR);
	gpr[i.r0] = interp_wrap_add(gpr[i.r1], gpr[i.r2]);
	DISPATCH();

	DO(ADDR_F);
	fpr[i.r0] = fpr[i.r1] + fpr[i.r2];
	DISPATCH();

	DO(ADDI);
	gpr[i.r0] = interp_wrap_add(gpr[i.r1], i.o);
	DISPATCH();

	DO(ABSR_F);
	fpr[i.r0] = fabs(fpr[i.r1]);
	DISPATCH();

	DO(SUBR);
	gpr[i.r0] = interp_wrap_sub(gpr[i.r1], gpr[i.r2]);
	DISPATCH();

	DO(SUBR_F);
	fpr[i.r0] = fpr[i.r1] - fpr[i.r2];
	DISPATCH();

	DO(SUBI);
	gpr[i.r0] = interp_wrap_sub(gpr[i.r1], i.o);
	DISPATCH();

	DO(MULR);
	gpr[i.r0] = interp_wrap_mul(gpr[i.r1], gpr[i.r2]);
	DISPATCH();

	DO(DIVR);
	/* NOTE(review): division by zero and LONG_MIN / -1 are undefined
	 * behaviour here; the guest semantics for them are not specified in
	 * this file. */
	gpr[i.r0] = gpr[i.r1] / gpr[i.r2];
	DISPATCH();

	DO(ANDR);
	gpr[i.r0] = gpr[i.r1] & gpr[i.r2];
	DISPATCH();

	DO(ANDI);
	gpr[i.r0] = gpr[i.r1] & i.o;
	DISPATCH();

	DO(EQR);
	gpr[i.r0] = gpr[i.r1] == gpr[i.r2];
	DISPATCH();

	DO(LTR);
	gpr[i.r0] = gpr[i.r1] < gpr[i.r2];
	DISPATCH();

	DO(STXI64);
	int64_t *addr = (int64_t *)(gpr[i.r1] + i.o);
	*addr = gpr[i.r0];
	DISPATCH();

	DO(LDXIU64);
	uint64_t *addr = (uint64_t *)(gpr[i.r1] + i.o);
	gpr[i.r0] = *addr;
	DISPATCH();

	DO(BLTR);
	if (gpr[i.r1] < gpr[i.r2])
		JUMP(i.r0);

	DISPATCH();

	DO(BNEI);
	if (gpr[i.r1] != i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(BEQR);
	if (gpr[i.r1] == gpr[i.r2])
		JUMP(i.r0);

	DISPATCH();

	DO(BEQI);
	if (gpr[i.r1] == i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(BGTI);
	if (gpr[i.r1] > i.o)
		JUMP(i.r0);

	DISPATCH();

	DO(JMP);
	JUMP(i.r0);
	DISPATCH();

	DO(RETVAL);
	gpr[i.r0] = retval;
	DISPATCH();

	/* NOTE(review): for nested calls `args` points into state->args.buf
	 * (see CALLI). If this function appended with ARG* before its last
	 * PARAM and that append reallocated the buffer, `args` would dangle;
	 * presumably PARAMs are always emitted first — confirm. */
	DO(PARAM);
	gpr[i.r2] = args[i.r0].u64;
	DISPATCH();

	DO(PARAM_F);
	fpr[i.r2] = args[i.r0].d;
	DISPATCH();

	DO(ARG);
	struct ejit_arg a = ejit_build_arg(i.r1, gpr[i.r2]);
	vec_append(&state->args, &a);
	DISPATCH();

	DO(ARG_I);
	struct ejit_arg a = ejit_build_arg(i.r1, i.o);
	vec_append(&state->args, &a);
	DISPATCH();

	DO(ARG_F);
	struct ejit_arg a = ejit_build_arg_f(i.r1, fpr[i.r2]);
	vec_append(&state->args, &a);
	DISPATCH();

	DO(ARG_FI);
	struct ejit_arg a = ejit_build_arg_f(i.r1, i.d);
	vec_append(&state->args, &a);
	DISPATCH();

	DO(CALLI);
	/* The arguments for this call are everything pushed on state->args
	 * since we entered. */
	struct ejit_func *f = i.p;
	size_t argc = vec_len(&state->args) - prev_argc;
	struct ejit_arg *args = ((struct ejit_arg *)state->args.buf) +
	                        prev_argc;

	retval = ejit_run_interp(f, argc, args, state);

	gpr = ((long *)state->gprs.buf) + prev_gprs;
	fpr = ((double *)state->fprs.buf) + prev_fprs;
	vec_shrink(&state->args, prev_argc);
	DISPATCH();

	DO(CALLI_F);
	/* Same calling sequence as CALLI, for a callee returning a double.
	 * Popping the arguments matters: anything left on state->args would
	 * be passed as leading arguments to the next call.
	 * NOTE(review): this calls ejit_interp directly rather than
	 * ejit_run_interp (which CALLI uses for integer results). It assumes
	 * the callee has been prepared for interpretation the same way CALLI
	 * callees are; use a float counterpart of ejit_run_interp instead if
	 * one exists. There is no RETVAL_F handler in the table above, so the
	 * result is only kept in retval_f for now. */
	struct ejit_func *f = i.p;
	size_t argc = vec_len(&state->args) - prev_argc;
	struct ejit_arg *args = ((struct ejit_arg *)state->args.buf) +
	                        prev_argc;

	retval_f = ejit_interp(f, argc, args, state, true, NULL).d;

	gpr = ((long *)state->gprs.buf) + prev_gprs;
	fpr = ((double *)state->fprs.buf) + prev_fprs;
	vec_shrink(&state->args, prev_argc);
	DISPATCH();

	DO(ESCAPEI);
	/* Escape functions only receive the argument array, not `state`, so
	 * the register frame pointers stay valid here. */
	ejit_escape_t f = i.p;
	size_t argc = vec_len(&state->args) - prev_argc;
	struct ejit_arg *args = ((struct ejit_arg *)state->args.buf) +
	                        prev_argc;

	retval = f(argc, args);

	vec_shrink(&state->args, prev_argc);
	DISPATCH();

	/* The jump emitted by DISPATCH() is unreachable after a return, but
	 * DISPATCH() itself is still required: it closes the scope DO()
	 * opened. */
	DO(RETR);
	retval = gpr[i.r0];
	goto out_int;
	DISPATCH();

	DO(RETI);
	retval = i.o;
	goto out_int;
	DISPATCH();

	DO(RETR_F);
	retval_f = fpr[i.r0];
	goto out_float;
	DISPATCH();

	DO(RETI_F);
	retval_f = i.d;
	goto out_float;
	DISPATCH();

#undef DISPATCH
#undef JUMP
#undef DO

	/* Pop this call's register frame before returning. */
out_int:
	vec_shrink(&state->gprs, prev_gprs);
	vec_shrink(&state->fprs, prev_fprs);
	return (union interp_ret){.r = retval};

out_float:
	vec_shrink(&state->gprs, prev_gprs);
	vec_shrink(&state->fprs, prev_fprs);
	return (union interp_ret){.d = retval_f};
}