#include <ejit/ejit.h>

#include "common.h"


/**
 * Threaded (computed-goto) interpreter for ejit functions.
 *
 * Operates in one of two modes:
 *  - run == false: stores the static opcode -> label-address table into
 *    *labels_wb and returns {.r = 0} without executing anything. Presumably
 *    the caller uses this table to fill in each instruction's .addr field
 *    before execution — TODO confirm against the compile step.
 *  - run == true: executes f->insns starting at index 0 and returns the
 *    value produced by the first RET / RET_I / RET_F / RET_FI / END reached.
 *    labels_wb is not accessed in this mode.
 *
 * Execution falls straight into the START handler without dispatching, so
 * this assumes insns[0] is a START instruction — TODO confirm.
 *
 * Integer virtual registers live in a heap array of f->gpr longs and
 * floating-point virtual registers in a heap array of f->fpr doubles. Both
 * arrays, and the pending call-argument vector, are released before
 * returning.
 *
 * @param f         function to interpret
 * @param argc      number of entries in @p args
 * @param args      arguments read by PARAM / PARAM_F (indexed by i.o)
 * @param run       false to only export the label table, true to execute
 * @param labels_wb out-parameter for the label table when @p run is false
 * @return .r for integer/void returns, .d for floating-point returns
 */
union interp_ret ejit_interp(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], bool run, void ***labels_wb)
{
	static void *labels[OPCODE_COUNT] = {
		[MOVI] = &&movi,
		[MOVR] = &&movr,
		[MOVR_F] = &&movr_f,

		[ADDR] = &&addr,
		[ADDR_F] = &&addr_f,
		[SUBR] = &&subr,
		[SUBR_F] = &&subr_f,

		[BLTR] = &&bltr,

		[RET] = &&ret,
		[RET_I] = &&ret_i,
		[RET_F] = &&ret_f,
		[RET_FI] = &&ret_fi,

		[ARG] = &&arg,
		[ARG_I] = &&arg_i,
		[ARG_F] = &&arg_f,
		[ARG_FI] = &&arg_fi,

		[PARAM] = &&param,
		[PARAM_F] = &&param_f,

		[CALLI] = &&calli,
		[CALLI_F] = &&calli_f,

		[START] = &&start,
		[END] = &&end,
	};

	if (!run) {
		*labels_wb = labels;
		return (union interp_ret){.r = 0};
	}

	/** @todo this can be optimized by for example using a common buffer
	 * 'stack' */
	/* calloc zero-initializes the register files, so reading a register
	 * before it has been written yields 0 rather than an indeterminate
	 * value. The floating-point register file must hold doubles; storing
	 * float values in longs would truncate them to integers. */
	long *gpr = calloc(f->gpr, sizeof *gpr);
	double *fpr = calloc(f->fpr, sizeof *fpr);
	/* calloc(0, ...) may legitimately return NULL, so only a NULL result
	 * for a non-empty register file is an allocation failure. */
	assert((gpr || f->gpr == 0) && (fpr || f->fpr == 0));

	/* arguments accumulated by ARG* and consumed by the next CALLI* */
	struct vec call_args = vec_create(sizeof(struct ejit_arg));
	struct ejit_insn *insns = f->insns.buf;

	struct ejit_insn i;
	long retval = 0; double retval_f = 0.;
	size_t pc = 0;

	/* DO opens a handler and loads the current instruction into i,
	 * JUMP transfers control to instruction a, DISPATCH closes the handler
	 * and falls through to the next instruction's handler. */
#define DO(x) x: { i = insns[pc];
#define JUMP(a) goto *insns[pc = a].addr;
#define DISPATCH() } goto *insns[++pc].addr;

	DO(start);
	DISPATCH();

	/* END returns whatever retval currently holds: 0 unless a CALLI has
	 * already stored a result in it. */
	DO(end);
	goto out_int;
	DISPATCH();

	DO(movi);
	gpr[i.r0] = i.o;
	DISPATCH();

	DO(movr);
	gpr[i.r0] = gpr[i.r1];
	DISPATCH();

	DO(movr_f);
	fpr[i.r0] = fpr[i.r1];
	DISPATCH();

	DO(addr);
	gpr[i.r0] = gpr[i.r1] + gpr[i.r2];
	DISPATCH();

	DO(addr_f);
	fpr[i.r0] = fpr[i.r1] + fpr[i.r2];
	DISPATCH();

	DO(subr);
	gpr[i.r0] = gpr[i.r1] - gpr[i.r2];
	DISPATCH();

	DO(subr_f);
	fpr[i.r0] = fpr[i.r1] - fpr[i.r2];
	DISPATCH();

	/* branch to instruction i.o when r0 < r1 (signed) */
	DO(bltr);
	if (gpr[i.r0] < gpr[i.r1]) {
		JUMP(i.o);
	}

	DISPATCH();

	DO(param);
	assert((size_t)i.o < argc);
	gpr[i.r0] = args[i.o].l;
	DISPATCH();

	DO(param_f);
	assert((size_t)i.o < argc);
	fpr[i.r0] = args[i.o].d;
	DISPATCH();

	DO(arg);
	struct ejit_arg a = ejit_build_arg(i.r1, gpr[i.r0]);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(arg_i);
	struct ejit_arg a = ejit_build_arg(i.r1, i.o);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(arg_f);
	struct ejit_arg a = ejit_build_arg_f(i.r1, fpr[i.r0]);
	vec_append(&call_args, &a);
	DISPATCH();

	DO(arg_fi);
	struct ejit_arg a = ejit_build_arg_f(i.r1, i.d);
	vec_append(&call_args, &a);
	DISPATCH();

	/* calls run recursively through the interpreter; the collected
	 * arguments are discarded afterwards */
	DO(calli);
	struct ejit_func *callee = i.p;
	retval = ejit_run_func(callee, vec_len(&call_args), call_args.buf);
	vec_reset(&call_args);
	DISPATCH();

	DO(calli_f);
	struct ejit_func *callee = i.p;
	retval_f = ejit_run_func_f(callee, vec_len(&call_args), call_args.buf);
	vec_reset(&call_args);
	DISPATCH();

	/* dispatch is technically unnecessary for returns, but keep it for
	 * symmetry */
	DO(ret);
	retval = gpr[i.r0];
	goto out_int;
	DISPATCH();

	DO(ret_i);
	retval = i.o;
	goto out_int;
	DISPATCH();

	DO(ret_f);
	retval_f = fpr[i.r0];
	goto out_float;
	DISPATCH();

	DO(ret_fi);
	retval_f = i.d;
	goto out_float;
	DISPATCH();

#undef DISPATCH
#undef JUMP
#undef DO

out_int:
	vec_destroy(&call_args);
	free(gpr);
	free(fpr);
	return (union interp_ret){.r = retval};

out_float:
	vec_destroy(&call_args);
	free(gpr);
	free(fpr);
	return (union interp_ret){.d = retval_f};
}

long ejit_run_func(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
{
	assert(f->rtype == EJIT_VOID || ejit_int_type(f->rtype));
	return ejit_interp(f, argc, args, true, NULL).r;
}

double ejit_run_func_f(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
{
	assert(ejit_float_type(f->rtype));
	return ejit_interp(f, argc, args, true, NULL).d;
}