From 057131cb20fb1c46e90adecfb4a16eb62f100580 Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Wed, 9 Apr 2025 20:20:48 +0300
Subject: add taili

---
 include/ejit/ejit.h   |  3 +++
 src/common.h          |  2 ++
 src/compile/compile.c | 35 +++++++++++++++++++++++++++++++++++
 src/ejit.c            | 44 +++++++++++++++++++++++++++++++++++++++++++-
 src/interp.c          | 14 ++++++++++++++
 tests/taili.c         | 39 +++++++++++++++++++++++++++++++++++++++
 tests/tailr.c         | 41 +++++++++++++++++++++++++++++++++++++++++
 7 files changed, 177 insertions(+), 1 deletion(-)
 create mode 100644 tests/taili.c
 create mode 100644 tests/tailr.c

diff --git a/include/ejit/ejit.h b/include/ejit/ejit.h
index aa42eca..5baaab6 100644
--- a/include/ejit/ejit.h
+++ b/include/ejit/ejit.h
@@ -457,6 +457,9 @@ struct ejit_label ejit_label(struct ejit_func *s);
 void ejit_tailr(struct ejit_func *s, struct ejit_gpr target,
 		size_t argc, const struct ejit_operand args[argc]);
 
+void ejit_taili(struct ejit_func *s, struct ejit_func *f,
+		size_t argc, const struct ejit_operand args[argc]);
+
 void ejit_calli(struct ejit_func *s, struct ejit_func *f, size_t argc,
                 const struct ejit_operand args[argc]);
 
diff --git a/src/common.h b/src/common.h
index 3512717..333c794 100644
--- a/src/common.h
+++ b/src/common.h
@@ -219,7 +219,9 @@ enum ejit_opcode {
 	EJIT_OP_ESCAPEI_D,
 
 	EJIT_OP_CALLI,
+
 	EJIT_OP_TAILR,
+	EJIT_OP_TAILI,
 
 	EJIT_OP_RETR,
 	EJIT_OP_RETI,
diff --git a/src/compile/compile.c b/src/compile/compile.c
index 54d79f2..60059d5 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -2505,6 +2505,40 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 			break;
 		}
 
+		case EJIT_OP_TAILI: {
+			/* direct tail call; largely copy-pasted from the
+			 * EJIT_OP_TAILR case below, hmm */
+			assert(operands_len(&direct) <= 2); /* regs[]/args[] below only hold 2 */
+			struct ejit_func *callee = (struct ejit_func *)(uintptr_t)i.o;
+			assert(callee->direct_call); /* callee must already have a direct entry point */
+
+			jit_operand_t regs[2] = { /* staging registers for the outgoing arguments */
+				jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R1),
+				jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R2)
+			};
+			jit_move_operands(j, regs, direct.buf, operands_len(&direct));
+
+			int frame_size = j->frame_size; /* restored below, after the jump is emitted */
+			jit_shrink_stack(j, stack);
+			jit_leave_jit_abi(j, gprs, fprs, frame);
+
+			/* now move args into place */
+			jit_operand_t args[2] = {};
+			foreach_vec(oi, direct) {
+				args[oi] = *operands_at(&direct, oi);
+			}
+
+			jit_locate_args(j, operands_len(&direct), args);
+			jit_move_operands(j, args, regs, operands_len(&direct));
+			jit_jmpi(j, callee->direct_call); /* jump rather than call */
+			j->frame_size = frame_size;
+
+			operands_reset(&src);
+			operands_reset(&dst);
+			operands_reset(&direct);
+			break;
+		}
+
 		case EJIT_OP_TAILR: {
 			/* this is admittedly a slightly roundabout way of
 			 * implementing tail calls and is arguably not the most
@@ -2518,6 +2552,7 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 			jit_gpr_t r = getloc(f, j, i.r1, 0);
 			jit_ldxi(j, JIT_R0, r, offsetof(struct ejit_func, direct_call));
 #if defined(DEBUG)
+			/** @todo other checks? */
 			jit_reloc_t assert_reloc = jit_bnei(j, JIT_R0, 0); /* null */
 			jit_calli_1(j, assert_helper,
 					jit_operand_imm(JIT_OPERAND_ABI_POINTER,
diff --git a/src/ejit.c b/src/ejit.c
index 059d5d4..0701b90 100644
--- a/src/ejit.c
+++ b/src/ejit.c
@@ -456,12 +456,54 @@ void ejit_patch(struct ejit_func *f, struct ejit_reloc r, struct ejit_label l)
 	*insns_at(&f->insns, r.insn) = i;
 }
 
+void ejit_taili(struct ejit_func *s, struct ejit_func *f,
+		size_t argc, const struct ejit_operand args[argc])
+{ /* Emit into s a tail call that targets f directly, passing the argc operands in args. */
+	assert(s->rtype == f->rtype); /* caller and callee must share a return type */
+
+	s->max_args = argc > s->max_args ? argc : s->max_args; /* keep the largest argc seen in s */
+	check_operands(f, argc, args); /* presumably validates args against f's parameters -- confirm */
+
+	size_t gpr_args = 0, fpr_args = 0;
+	for (size_t i = 0; i < argc; ++i) { /* one ARG* instruction per operand, chosen by kind */
+		switch (args[i].kind) {
+		case EJIT_OPERAND_GPR:
+			gpr_args++;
+			emit_insn_ar(s, EJIT_OP_ARG, i, args[i].type, EJIT_GPR(args[i].r));
+			break;
+
+		case EJIT_OPERAND_FPR:
+			fpr_args++;
+			emit_insn_af(s, EJIT_OP_ARG_F, i, args[i].type, EJIT_FPR(args[i].r));
+			break;
+
+		case EJIT_OPERAND_IMM:
+			gpr_args++; /* integer immediates count against the gpr limit */
+			emit_insn_ai(s, EJIT_OP_ARG_I, i, args[i].type, args[i].r);
+			break;
+
+		case EJIT_OPERAND_FLT:
+			fpr_args++; /* float immediates count against the fpr limit */
+			emit_insn_ad(s, EJIT_OP_ARG_FI, i, args[i].type, args[i].d);
+			break;
+
+		default: abort(); /* unknown operand kind */
+		}
+	}
+
+	assert(gpr_args <= 2 && fpr_args == 0
+			&& "only 2 gpr args and 0 fpr args supported in tail calls for now");
+	emit_insn_op(s, EJIT_OP_TAILI, f); /* f itself is stored as the instruction's operand */
+}
+
 void ejit_tailr(struct ejit_func *s, struct ejit_gpr target, size_t argc,
                 const struct ejit_operand args[argc])
 {
 	s->max_args = argc > s->max_args ? argc : s->max_args;
 
-	/** @todo check that gpr_args <= 2 and fpr_args <= 3 (?) */
+	/* operands must match */
+	check_operands(s, argc, args);
+
 	size_t gpr_args = 0, fpr_args = 0;
 	for (size_t i = 0; i < argc; ++i) {
 		switch (args[i].kind) {
diff --git a/src/interp.c b/src/interp.c
index 132ba4a..268bfb3 100644
--- a/src/interp.c
+++ b/src/interp.c
@@ -214,7 +214,9 @@ union interp_ret ejit_run(struct ejit_func *f, size_t paramc, struct ejit_arg pa
 		[EJIT_OP_PARAM_F] = &&PARAM_F,
 
 		[EJIT_OP_CALLI] = &&CALLI,
+
 		[EJIT_OP_TAILR] = &&TAILR,
+		[EJIT_OP_TAILI] = &&TAILI,
 
 		[EJIT_OP_ESCAPEI_I] = &&ESCAPEI_I,
 		[EJIT_OP_ESCAPEI_F] = &&ESCAPEI_F,
@@ -1056,6 +1058,18 @@ top:
 	args[argc++] = a;
 	DISPATCH();
 
+	DO(TAILI);
+	f = (struct ejit_func *)(uintptr_t)i.o; /* callee is stored in the instruction operand */
+
+	assert(!f->direct_call && "trying to interpret compiled fun");
+
+	paramc = argc; /* the collected call arguments become the callee's parameters */
+	for (size_t i = 0; i < argc; ++i) /* NOTE(review): this index shadows the instruction `i` read above */
+		params[i] = args[i];
+
+	goto top; /* presumably re-enters the interpreter for the new f without recursing -- confirm */
+	DISPATCH();
+
 	DO(TAILR);
 	f = (struct ejit_func *)gpr[i.r1];
 
diff --git a/tests/taili.c b/tests/taili.c
new file mode 100644
index 0000000..cc09f59
--- /dev/null
+++ b/tests/taili.c
@@ -0,0 +1,39 @@
+#include <ejit/ejit.h>
+#include <assert.h>
+#include "do_jit.h"
+
+int main(int argc, char *argv[])
+{ /* test: accumulate n + (n - 1) + ... + 1 into s via a direct self tail call (ejit_taili) */
+	(void)argv;
+	bool do_jit = argc > 1; /* any command-line argument turns do_jit on */
+	struct ejit_operand operands[2] = {
+		EJIT_OPERAND_GPR(0, EJIT_INT32), /* s */
+		EJIT_OPERAND_GPR(1, EJIT_INT32)  /* n */
+	};
+
+	struct ejit_func *f = ejit_create_func(EJIT_INT32, 2, operands);
+
+	/* n == 0, return s */
+	struct ejit_reloc r = ejit_bnei(f, EJIT_GPR(1), 0);
+	ejit_retr(f, EJIT_GPR(0));
+	ejit_patch(f, r, ejit_label(f));
+
+	/* s += n */
+	ejit_addr(f, EJIT_GPR(0), EJIT_GPR(0), EJIT_GPR(1));
+
+	/* n -= 1 */
+	ejit_subi(f, EJIT_GPR(1), EJIT_GPR(1), 1);
+
+	struct ejit_operand args[2] = {
+		EJIT_OPERAND_GPR(0, EJIT_INT32), /* s */
+		EJIT_OPERAND_GPR(1, EJIT_INT32)  /* n */
+	};
+	ejit_taili(f, f, 2, args); /* f tail-calls itself with the updated s and n */
+
+	ejit_select_compile_func(f, 2, 0, EJIT_USE64(uintptr_t), do_jit, true);
+
+	/* expect 1 + ... + 1000000 truncated to 32 bits; arbitrary count but large
+	 * enough to most likely cause a stack fault if the tail call leaks memory */
+	assert((int32_t)erfi2(f, EJIT_ARG(0, int32_t), EJIT_ARG(1000000, int32_t)) == 1784293664);
+	ejit_destroy_func(f);
+}
diff --git a/tests/tailr.c b/tests/tailr.c
new file mode 100644
index 0000000..69ad44b
--- /dev/null
+++ b/tests/tailr.c
@@ -0,0 +1,41 @@
+#include <ejit/ejit.h>
+#include <assert.h>
+#include "do_jit.h"
+
+int main(int argc, char *argv[])
+{ /* test: accumulate n + (n - 1) + ... + 1 into s via a register-indirect self tail call (ejit_tailr) */
+	(void)argv;
+	bool do_jit = argc > 1; /* any command-line argument turns do_jit on */
+	struct ejit_operand operands[2] = {
+		EJIT_OPERAND_GPR(0, EJIT_INT32), /* s */
+		EJIT_OPERAND_GPR(1, EJIT_INT32)  /* n */
+	};
+
+	struct ejit_func *f = ejit_create_func(EJIT_INT32, 2, operands);
+
+	/* n == 0, return s */
+	struct ejit_reloc r = ejit_bnei(f, EJIT_GPR(1), 0);
+	ejit_retr(f, EJIT_GPR(0));
+	ejit_patch(f, r, ejit_label(f));
+
+	/* s += n */
+	ejit_addr(f, EJIT_GPR(0), EJIT_GPR(0), EJIT_GPR(1));
+
+	/* n -= 1 */
+	ejit_subi(f, EJIT_GPR(1), EJIT_GPR(1), 1);
+
+	struct ejit_operand args[2] = {
+		EJIT_OPERAND_GPR(0, EJIT_INT32), /* s */
+		EJIT_OPERAND_GPR(1, EJIT_INT32)  /* n */
+	};
+
+	ejit_movi(f, EJIT_GPR(2), (uintptr_t)f); /* r2 holds f's own address as the tail-call target */
+	ejit_tailr(f, EJIT_GPR(2), 2, args);
+
+	ejit_select_compile_func(f, 3, 0, EJIT_USE64(uintptr_t), do_jit, true);
+
+	/* expect 1 + ... + 1000000 truncated to 32 bits; arbitrary count but large
+	 * enough to most likely cause a stack fault if the tail call leaks memory */
+	assert((int32_t)erfi2(f, EJIT_ARG(0, int32_t), EJIT_ARG(1000000, int32_t)) == 1784293664);
+	ejit_destroy_func(f);
+}
-- 
cgit v1.2.3