From 912c07167705613c6db70e542723c7ec2c06c7ea Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Sat, 15 Mar 2025 13:16:54 +0200
Subject: experiment with allocating regs on stack in interp

+ Avoids having to lug around an execution context, arguably simplified
  things but now there's no real way to detect when we run out memory
  for regs.
---
 src/common.h          |   9 +----
 src/compile/compile.c |  31 ++++++++++++---
 src/ejit.c            |  79 +++++++++----------------------------
 src/interp.c          | 105 ++++++++++++++------------------------------------
 4 files changed, 74 insertions(+), 150 deletions(-)

(limited to 'src')

diff --git a/src/common.h b/src/common.h
index 6ee0df7..493a89d 100644
--- a/src/common.h
+++ b/src/common.h
@@ -288,13 +288,7 @@ struct ejit_func {
 	void *direct_call;
 	size_t size;
 	size_t prio;
-};
-
-
-struct interp_state {
-	struct gprs gprs;
-	struct fprs fprs;
-	struct args args;
+	size_t max_args;
 };
 
 union interp_ret {
@@ -304,7 +298,6 @@ union interp_ret {
 
 union interp_ret ejit_run(struct ejit_func *f, size_t argc,
                      struct ejit_arg args[argc],
-		     struct interp_state *state,
 		     bool run,
 		     void ***labels_wb);
 
diff --git a/src/compile/compile.c b/src/compile/compile.c
index 490bc43..c979305 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -19,6 +19,27 @@ struct reloc_helper {
 #define VEC_NAME addrs
 #include "../vec.h"
 
+/* skip assertions since we know they must be valid due to type checking earlier */
+static long checked_run_i(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
+{
+	return ejit_run(f, argc, args, true, NULL).i;
+}
+
+static int64_t checked_run_l(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
+{
+	return ejit_run(f, argc, args, true, NULL).i;
+}
+
+static float checked_run_f(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
+{
+	return ejit_run(f, argc, args, true, NULL).f;
+}
+
+static double checked_run_d(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
+{
+	return ejit_run(f, argc, args, true, NULL).f;
+}
+
 static void *alloc_arena(size_t size, bool im_scawed)
 {
 	return mmap(NULL, size,
@@ -2217,15 +2238,15 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 
 		case CALLI_L:
 #if __WORDSIZE == 64
-			 call = ejit_run_func_l; goto calli;
+			 call = checked_run_l; goto calli;
 #else
 			  assert(0 && "trying to compile calli_l on 32bit arch");
 			  break;
 #endif
 
-		case CALLI_F: { call = ejit_run_func_f; goto calli; }
-		case CALLI_D: { call = ejit_run_func_d; goto calli; }
-		case CALLI_I: { call = ejit_run_func_i; goto calli;
+		case CALLI_F: { call = checked_run_f; goto calli; }
+		case CALLI_D: { call = checked_run_d; goto calli; }
+		case CALLI_I: { call = checked_run_i; goto calli;
 calli:
 			save_caller_save_regs(f, j);
 
@@ -2425,7 +2446,7 @@ static size_t align_up(size_t a, size_t n)
 bool ejit_compile(struct ejit_func *f, bool use_64, bool im_scawed)
 {
 	(void)use_64;
-#if __WORDSIZE == 32
+#if __WORDSIZE != 64
 	/* can't compile 64bit code on 32bit systems, give up early */
 	if (use_64)
 		return false;
diff --git a/src/ejit.c b/src/ejit.c
index 265acea..0a0e7a1 100644
--- a/src/ejit.c
+++ b/src/ejit.c
@@ -339,6 +339,7 @@ struct ejit_func *ejit_create_func(enum ejit_type rtype, size_t argc,
 	f->size = 0;
 	f->prio = 1;
 	f->use_64 = false;
+	f->max_args = 0;
 
 	for (size_t i = 0; i < argc; ++i) {
 		types_append(&f->sign, args[i].type);
@@ -403,7 +404,7 @@ void ejit_select_compile_func(struct ejit_func *f, size_t gpr, size_t fpr,
 
 	void **labels;
 	/* just get labels, don't actually run anything yet */
-	ejit_run(f, 0, NULL, NULL, false, &labels);
+	ejit_run(f, 0, NULL, false, &labels);
 
 	foreach_vec(ii, f->insns) {
 		struct ejit_insn i = *insns_at(&f->insns, ii);
@@ -448,6 +449,7 @@ void ejit_patch(struct ejit_func *f, struct ejit_reloc r, struct ejit_label l)
 void ejit_calli_i(struct ejit_func *s, struct ejit_func *f, size_t argc,
                 const struct ejit_operand args[argc])
 {
+	f->max_args = argc > f->max_args ? argc : f->max_args;
 	check_operands(f, argc, args);
 
 	for (size_t i = 0; i < argc; ++i) {
@@ -467,6 +469,7 @@ void ejit_calli_l(struct ejit_func *s, struct ejit_func *f, size_t argc,
                 const struct ejit_operand args[argc])
 {
 	s->use_64 = true;
+	f->max_args = argc > f->max_args ? argc : f->max_args;
 	check_operands(f, argc, args);
 
 	for (size_t i = 0; i < argc; ++i) {
@@ -485,6 +488,7 @@ void ejit_calli_l(struct ejit_func *s, struct ejit_func *f, size_t argc,
 void ejit_calli_f(struct ejit_func *s, struct ejit_func *f, size_t argc,
                 const struct ejit_operand args[argc])
 {
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	check_operands(f, argc, args);
 
 	for (size_t i = 0; i < argc; ++i) {
@@ -503,6 +507,7 @@ void ejit_calli_f(struct ejit_func *s, struct ejit_func *f, size_t argc,
 void ejit_calli_d(struct ejit_func *s, struct ejit_func *f, size_t argc,
                 const struct ejit_operand args[argc])
 {
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	check_operands(f, argc, args);
 
 	for (size_t i = 0; i < argc; ++i) {
@@ -521,6 +526,7 @@ void ejit_calli_d(struct ejit_func *s, struct ejit_func *f, size_t argc,
 void ejit_escapei_i(struct ejit_func *s, ejit_escape_i_t f, size_t argc,
                   const struct ejit_operand args[argc])
 {
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	for (size_t i = 0; i < argc; ++i) {
 		switch (args[i].kind) {
 		case EJIT_OPERAND_GPR: emit_insn_ar(s, ARG, i, args[i].type, EJIT_GPR(args[i].r)); break;
@@ -538,6 +544,7 @@ void ejit_escapei_l(struct ejit_func *s, ejit_escape_l_t f, size_t argc,
                   const struct ejit_operand args[argc])
 {
 	s->use_64 = true;
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	for (size_t i = 0; i < argc; ++i) {
 		switch (args[i].kind) {
 		case EJIT_OPERAND_GPR: emit_insn_ar(s, ARG, i, args[i].type, EJIT_GPR(args[i].r)); break;
@@ -554,6 +561,7 @@ void ejit_escapei_l(struct ejit_func *s, ejit_escape_l_t f, size_t argc,
 void ejit_escapei_f(struct ejit_func *s, ejit_escape_f_t f, size_t argc,
                     const struct ejit_operand args[argc])
 {
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	for (size_t i = 0; i < argc; ++i) {
 		switch (args[i].kind) {
 		case EJIT_OPERAND_GPR: emit_insn_ar(s, ARG, i, args[i].type, EJIT_GPR(args[i].r)); break;
@@ -570,6 +578,7 @@ void ejit_escapei_f(struct ejit_func *s, ejit_escape_f_t f, size_t argc,
 void ejit_escapei_d(struct ejit_func *s, ejit_escape_d_t f, size_t argc,
                     const struct ejit_operand args[argc])
 {
+	s->max_args = argc > s->max_args ? argc : s->max_args;
 	for (size_t i = 0; i < argc; ++i) {
 		switch (args[i].kind) {
 		case EJIT_OPERAND_GPR: emit_insn_ar(s, ARG, i, args[i].type, EJIT_GPR(args[i].r)); break;
@@ -1647,23 +1656,8 @@ struct ejit_reloc ejit_bmsr(struct ejit_func *s, struct ejit_gpr r0,
 	return (struct ejit_reloc){.insn = addr};
 }
 
-static struct interp_state create_interp_state()
-{
-	struct interp_state state;
-	state.gprs = gprs_create();
-	state.fprs = fprs_create();
-	state.args = args_create();
-	return state;
-}
-
-static void destroy_interp_state(struct interp_state state)
-{
-	gprs_destroy(&state.gprs);
-	fprs_destroy(&state.fprs);
-	args_destroy(&state.args);
-}
-
-long ejit_run_func_ctx_i(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], struct interp_state *ctx)
+long ejit_run_func_i(struct ejit_func *f, size_t argc,
+                      struct ejit_arg args[argc])
 {
 	check_args(f, argc, args);
 	assert((f->rtype == EJIT_VOID || ejit_int_type(f->rtype))
@@ -1672,64 +1666,29 @@ long ejit_run_func_ctx_i(struct ejit_func *f, size_t argc, struct ejit_arg args[
 #endif
 		);
 
-	return ejit_run(f, argc, args, ctx, true, NULL).i;
+	return ejit_run(f, argc, args, true, NULL).i;
 }
 
-long ejit_run_func_i(struct ejit_func *f, size_t argc,
+int64_t ejit_run_func_l(struct ejit_func *f, size_t argc,
                       struct ejit_arg args[argc])
-{
-	struct interp_state state = create_interp_state();
-	long r = ejit_run_func_ctx_i(f, argc, args, &state);
-	destroy_interp_state(state);
-	return r;
-}
-
-int64_t ejit_run_func_ctx_l(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], struct interp_state *ctx)
 {
 	check_args(f, argc, args);
 	assert(f->rtype == EJIT_INT64 || f->rtype == EJIT_UINT64);
-	return ejit_run(f, argc, args, ctx, true, NULL).i;
-}
-
-int64_t ejit_run_func_l(struct ejit_func *f, size_t argc,
-                      struct ejit_arg args[argc])
-{
-	struct interp_state state = create_interp_state();
-	int64_t r = ejit_run_func_ctx_l(f, argc, args, &state);
-	destroy_interp_state(state);
-	return r;
+	return ejit_run(f, argc, args, true, NULL).i;
 }
 
-float ejit_run_func_ctx_f(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], struct interp_state *ctx)
+float ejit_run_func_f(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
 	check_args(f, argc, args);
 	assert(f->rtype == EJIT_FLOAT);
-	return ejit_run(f, argc, args, ctx, true, NULL).f;
+	return ejit_run(f, argc, args, true, NULL).f;
 }
 
-float ejit_run_func_f(struct ejit_func *f, size_t argc,
-                       struct ejit_arg args[argc])
-{
-	struct interp_state state = create_interp_state();
-	float r = ejit_run_func_ctx_f(f, argc, args, &state);
-	destroy_interp_state(state);
-	return r;
-}
-
-double ejit_run_func_ctx_d(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], struct interp_state *ctx)
+double ejit_run_func_d(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
 	check_args(f, argc, args);
 	assert(f->rtype == EJIT_DOUBLE);
-	return ejit_run(f, argc, args, ctx, true, NULL).f;
-}
-
-double ejit_run_func_d(struct ejit_func *f, size_t argc,
-                       struct ejit_arg args[argc])
-{
-	struct interp_state state = create_interp_state();
-	double r = ejit_run_func_ctx_d(f, argc, args, &state);
-	destroy_interp_state(state);
-	return r;
+	return ejit_run(f, argc, args, true, NULL).f;
 }
 
 struct ejit_arg ejit_run_func(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
diff --git a/src/interp.c b/src/interp.c
index f8ba927..80a9edc 100644
--- a/src/interp.c
+++ b/src/interp.c
@@ -5,7 +5,7 @@
 /* this is the body of a given ejit_interp function, it assumes there's an
  * external int64_t retval and double retval_f into which it places the value to
  * be returned. Included from src/interp.c */
-union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args[argc], struct interp_state *state, bool run, void ***labels_wb)
+union interp_ret ejit_run(struct ejit_func *f, size_t paramc, struct ejit_arg params[paramc], bool run, void ***labels_wb)
 {
 	static void *labels[OPCODE_COUNT] = {
 		[MOVI] = &&MOVI,
@@ -227,40 +227,35 @@ union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args
 	if (f->arena) {
 		if (f->rtype == EJIT_INT64 || f->rtype == EJIT_UINT64)
 			return (union interp_ret){
-				.i = ((ejit_escape_l_t)f->arena)(argc, args)
+				.i = ((ejit_escape_l_t)f->arena)(paramc, params)
 			};
 
 		if (f->rtype == EJIT_DOUBLE)
 			return (union interp_ret){
-				.f = ((ejit_escape_d_t)f->arena)(argc, args)
+				.f = ((ejit_escape_d_t)f->arena)(paramc, params)
 			};
 
 		if (f->rtype == EJIT_FLOAT)
 			return (union interp_ret){
-				.f = ((ejit_escape_f_t)f->arena)(argc, args)
+				.f = ((ejit_escape_f_t)f->arena)(paramc, params)
 			};
 
 		return (union interp_ret){
-			.i = ((ejit_escape_i_t)f->arena)(argc, args)
+			.i = ((ejit_escape_i_t)f->arena)(paramc, params)
 		};
 	}
 
 	int64_t retval = 0; double retval_f = 0.0;
 
-	size_t prev_gprs = gprs_len(&state->gprs);
-	size_t prev_fprs = fprs_len(&state->fprs);
-	size_t prev_argc = args_len(&state->args);
-
-	gprs_reserve(&state->gprs, prev_gprs + gpr_stats_len(&f->gpr));
-	fprs_reserve(&state->fprs, prev_fprs + fpr_stats_len(&f->fpr));
-
 	union fpr {
 		double d;
 		float f;
 	};
-	int64_t *gpr = ((int64_t *)state->gprs.buf) + prev_gprs;
-	union fpr *fpr = ((union fpr *)state->fprs.buf) + prev_fprs;
 
+	size_t argc = 0;
+	int64_t *gpr = alloca(sizeof(int64_t) * gpr_stats_len(&f->gpr));
+	union fpr *fpr = alloca(sizeof(int64_t) * fpr_stats_len(&f->fpr));
+	struct ejit_arg *args = alloca(sizeof(struct ejit_arg) * f->max_args);
 	struct ejit_insn *insns = f->insns.buf;
 
 	/* retval is kind of an unfortunate extra bit of state to keep track of,
@@ -978,25 +973,25 @@ union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args
 	DISPATCH();
 
 	DO(PARAM);
-	gpr[i.r2] = args[i.r0].u64;
+	gpr[i.r2] = params[i.r0].u64;
 	DISPATCH();
 
 	DO(PARAM_F);
 	if (i.r1 == EJIT_FLOAT)
-		fpr[i.r2].f = args[i.r0].f;
+		fpr[i.r2].f = params[i.r0].f;
 	else
-		fpr[i.r2].d = args[i.r0].d;
+		fpr[i.r2].d = params[i.r0].d;
 
 	DISPATCH();
 
 	DO(ARG);
 	struct ejit_arg a = ejit_build_arg(i.r1, gpr[i.r2]);
-	args_append(&state->args, a);
+	args[argc++] = a;
 	DISPATCH();
 
 	DO(ARG_I);
 	struct ejit_arg a = ejit_build_arg(i.r1, i.o);
-	args_append(&state->args, a);
+	args[argc++] = a;
 	DISPATCH();
 
 	DO(ARG_F);
@@ -1006,7 +1001,7 @@ union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args
 	else
 		a = ejit_build_arg_f(i.r1, fpr[i.r2].f);
 
-	args_append(&state->args, a);
+	args[argc++] = a;
 	DISPATCH();
 
 	DO(ARG_FI);
@@ -1016,95 +1011,55 @@ union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args
 	else
 		a = ejit_build_arg_f(i.r1, i.f);
 
-	args_append(&state->args, a);
+	args[argc++] = a;
 	DISPATCH();
 
 	DO(CALLI_I);
 	struct ejit_func *f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
-	retval = ejit_run(f, argc, args, state, true, NULL).i;
-
-	gpr = state->gprs.buf + prev_gprs;
-	fpr = (union fpr *)state->fprs.buf + prev_fprs;
-	args_shrink(&state->args, prev_argc);
+	retval = ejit_run(f, argc, args, true, NULL).i;
+	argc = 0;
 	DISPATCH();
 
 	DO(CALLI_L);
 	struct ejit_func *f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
-	retval = ejit_run(f, argc, args, state, true, NULL).i;
-
-	gpr = state->gprs.buf + prev_gprs;
-	fpr = (union fpr *)state->fprs.buf + prev_fprs;
-	args_shrink(&state->args, prev_argc);
+	retval = ejit_run(f, argc, args, true, NULL).i;
+	argc = 0;
 	DISPATCH();
 
 	DO(CALLI_F);
 	struct ejit_func *f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
-	retval_f = ejit_run(f, argc, args, state, true, NULL).f;
-
-	gpr = state->gprs.buf + prev_gprs;
-	fpr = (union fpr *)state->fprs.buf + prev_fprs;
-	args_shrink(&state->args, prev_argc);
+	retval_f = ejit_run(f, argc, args, true, NULL).f;
+	argc = 0;
 	DISPATCH();
 
 	DO(CALLI_D);
 	struct ejit_func *f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
-	retval_f = ejit_run(f, argc, args, state, true, NULL).f;
-
-	gpr = state->gprs.buf + prev_gprs;
-	fpr = (union fpr *)state->fprs.buf + prev_fprs;
-	args_shrink(&state->args, prev_argc);
+	retval_f = ejit_run(f, argc, args, true, NULL).f;
+	argc = 0;
 	DISPATCH();
 
 	DO(ESCAPEI_I);
 	ejit_escape_i_t f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
 	retval = f(argc, args);
-
-	args_shrink(&state->args, prev_argc);
+	argc = 0;
 	DISPATCH();
 
 	DO(ESCAPEI_L);
 	ejit_escape_l_t f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
 	retval = f(argc, args);
-
-	args_shrink(&state->args, prev_argc);
+	argc = 0;
 	DISPATCH();
 
 	DO(ESCAPEI_F);
 	ejit_escape_f_t f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
 	retval_f = f(argc, args);
-
-	args_shrink(&state->args, prev_argc);
+	argc = 0;
 	DISPATCH();
 
 	DO(ESCAPEI_D);
 	ejit_escape_d_t f = i.p;
-	size_t argc = args_len(&state->args) - prev_argc;
-	struct ejit_arg *args = state->args.buf + prev_argc;
-
 	retval_f = f(argc, args);
-
-	args_shrink(&state->args, prev_argc);
+	argc = 0;
 	DISPATCH();
 
 	/* dispatch is technically unnecessary for returns, but keep it for
@@ -1144,13 +1099,9 @@ union interp_ret ejit_run(struct ejit_func *f, size_t argc, struct ejit_arg args
 #undef DO
 
 out_float:
-	gprs_shrink(&state->gprs, prev_gprs);
-	fprs_shrink(&state->fprs, prev_fprs);
 	return (union interp_ret){.f = retval_f};
 
 out_int:
-	gprs_shrink(&state->gprs, prev_gprs);
-	fprs_shrink(&state->fprs, prev_fprs);
 	return (union interp_ret){.i = retval};
 
 zero_out:
-- 
cgit v1.2.3