From 6824dd4b1ee22184f0e600115db3998924ed39d6 Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Wed, 9 Apr 2025 19:56:33 +0300
Subject: initial tail call stuff

---
 src/compile/compile.c | 106 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 89 insertions(+), 17 deletions(-)

(limited to 'src/compile')

diff --git a/src/compile/compile.c b/src/compile/compile.c
index dcf662b..54d79f2 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -22,22 +22,22 @@ struct reloc_helper {
 /* skip assertions since we know they must be valid due to type checking earlier */
 static long checked_run_i(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
-	return ejit_run(f, argc, args, true, NULL).i;
+	return ejit_run(f, argc, args, NULL).i;
 }
 
 static int64_t checked_run_l(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
-	return ejit_run(f, argc, args, true, NULL).i;
+	return ejit_run(f, argc, args, NULL).i;
 }
 
 static float checked_run_f(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
-	return ejit_run(f, argc, args, true, NULL).f;
+	return ejit_run(f, argc, args, NULL).f;
 }
 
 static double checked_run_d(struct ejit_func *f, size_t argc, struct ejit_arg args[argc])
 {
-	return ejit_run(f, argc, args, true, NULL).f;
+	return ejit_run(f, argc, args, NULL).f;
 }
 
 static void *alloc_arena(size_t size, bool im_scawed)
@@ -47,6 +47,11 @@ static void *alloc_arena(size_t size, bool im_scawed)
 	            MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
 }
 
+static void assert_helper(const char *msg)
+{
+	assert(false && msg);
+}
+
 static void free_arena(void *arena, size_t size)
 {
 	munmap(arena, size);
@@ -2042,8 +2047,6 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 	struct addrs addrs = addrs_create();
 	addrs_reserve(&addrs, insns_len(&f->insns));
 
-	void *call = NULL;
-
 	size_t label = 0;
 	foreach_vec(ii, f->insns) {
 		/* if we've hit a label, add it to our vector of label addresses */
@@ -2502,21 +2505,64 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 			break;
 		}
 
-		case EJIT_OP_CALLI_L:
-#if __WORDSIZE == 64
-			 call = checked_run_l; goto calli;
-#else
-			  assert(0 && "trying to compile calli_l on 32bit arch");
-			  break;
+		case EJIT_OP_TAILR: {
+			/* this is admittedly a slightly roundabout way of
+			 * implementing tail calls and is arguably not the most
+			 * performant way (if it works at all, heh) but for now
+			 * I'm more interested in functionality than raw
+			 * performance. Currently only supports two gpr
+			 * registers, but should be fairly easy to extend with
+			 * fprs as well */
+
+			assert(operands_len(&direct) <= 2);
+			jit_gpr_t r = getloc(f, j, i.r1, 0);
+			jit_ldxi(j, JIT_R0, r, offsetof(struct ejit_func, direct_call));
+#if defined(DEBUG)
+			jit_reloc_t assert_reloc = jit_bnei(j, JIT_R0, 0); /* null */
+			jit_calli_1(j, assert_helper,
+					jit_operand_imm(JIT_OPERAND_ABI_POINTER,
+						(jit_imm_t)"trying to tail call interpreted function"));
+			jit_patch_here(j, assert_reloc);
 #endif
+			jit_operand_t regs[2] = {
+				jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R1),
+				jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R2)
+			};
+			jit_move_operands(j, regs, direct.buf, operands_len(&direct));
+
+			/* with args safely in registers, reset stack/state
+			 * while avoiding overwriting the call target */
+			jit_gpr_t tmp = get_callr_temp(j);
+			jit_movr(j, tmp, JIT_R0);
+
+			int frame_size = j->frame_size;
+			jit_shrink_stack(j, stack);
+			jit_leave_jit_abi(j, gprs, fprs, frame);
+
+			/* now move args into place */
+			jit_operand_t args[2] = {};
+			foreach_vec(oi, direct) {
+				args[oi] = *operands_at(&direct, oi);
+			}
 
-		case EJIT_OP_CALLI_F: { call = checked_run_f; goto calli; }
-		case EJIT_OP_CALLI_D: { call = checked_run_d; goto calli; }
-		case EJIT_OP_CALLI_I: { call = checked_run_i; goto calli;
-calli:
+			jit_locate_args(j, operands_len(&direct), args);
+			jit_move_operands(j, args, regs, operands_len(&direct));
+			jit_jmpr(j, tmp);
+			j->frame_size = frame_size;
+
+			operands_reset(&src);
+			operands_reset(&dst);
+			operands_reset(&direct);
+			break;
+		}
+
+		case EJIT_OP_CALLI: {
 			save_caller_save_regs(f, j);
 
 			struct ejit_func *f = (struct ejit_func *)(uintptr_t)i.o;
+#if __WORDSIZE != 64
+			assert(f->rtype != EJIT_INT64 && f->rtype != EJIT_UINT64);
+#endif
 			if (f && f->direct_call) {
 				jit_calli(j, f->direct_call, operands_len(&direct), direct.buf);
 				restore_caller_save_regs(f, j);
@@ -2535,6 +2581,16 @@ calli:
 				 * argument stack address */
 				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_R0)
 			};
+
+			void *call = NULL;
+			switch (f->rtype) {
+			case EJIT_INT64:
+			case EJIT_UINT64: call = checked_run_l; break;
+			case EJIT_FLOAT: call = checked_run_f; break;
+			case EJIT_DOUBLE: call = checked_run_d; break;
+			default: call = checked_run_i; break;
+			}
+
 			compile_imm_call(j, &src, &dst, call, 3, args);
 			restore_caller_save_regs(f, j);
 
@@ -2552,39 +2608,55 @@ calli:
 			jit_gpr_t r = getloc(f, j, i.r1, 0);
 			/* R0 won't get overwritten by jit_leave_jit_abi */
 			jit_movr(j, JIT_R0, r);
+
+			/* keep track of frame size so we can continue
+			 * generating code after 'leaving' the ABI. Bit of a
+			 * hack, should maybe codify this better in the
+			 * lightening API? */
+			int frame_size = j->frame_size;
 			jit_shrink_stack(j, stack);
 			jit_leave_jit_abi(j, gprs, fprs, frame);
 			jit_retr(j, JIT_R0);
+			j->frame_size = frame_size;
 			break;
 		}
 
 		case EJIT_OP_RETR_F: {
 			jit_fpr_t r = getloc_f(f, j, i.r1, 0);
 			jit_movr_f(j, JIT_F0, r);
+
+			int frame_size = j->frame_size;
 			jit_shrink_stack(j, stack);
 			jit_leave_jit_abi(j, gprs, fprs, frame);
 			jit_retr_f(j, JIT_F0);
+			j->frame_size = frame_size;
 			break;
 		}
 
 		case EJIT_OP_RETR_D: {
 			jit_fpr_t r = getloc_d(f, j, i.r1, 0);
 			jit_movr_d(j, JIT_F0, r);
+
+			int frame_size = j->frame_size;
 			jit_shrink_stack(j, stack);
 			jit_leave_jit_abi(j, gprs, fprs, frame);
 			jit_retr_d(j, JIT_F0);
+			j->frame_size = frame_size;
 			break;
 		}
 
 		case EJIT_OP_RETI: {
+			int frame_size = j->frame_size;
 			jit_shrink_stack(j, stack);
 			jit_leave_jit_abi(j, gprs, fprs, frame);
 			jit_reti(j, i.o);
+			j->frame_size = frame_size;
 			break;
 		}
 
 		case EJIT_OP_END: {
-			/* 'void' return */
+			/* 'void' return, must be last thing in function so no
+			 * need to keep track of frame size */
 			jit_shrink_stack(j, stack);
 			jit_leave_jit_abi(j, gprs, fprs, frame);
 			jit_reti(j, 0);
-- 
cgit v1.2.3