From 57f6b41047e95374701ee276248f0f8615168450 Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Thu, 13 Mar 2025 19:20:56 +0200
Subject: improve register allocation

+ Still linear, but orders regs by some kind of priority
+ Use all registers available, not just callee-save
---
 src/compile/compile.c | 310 ++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 251 insertions(+), 59 deletions(-)

(limited to 'src/compile/compile.c')

diff --git a/src/compile/compile.c b/src/compile/compile.c
index 9a97b3a..c9c7652 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -31,119 +31,243 @@ static void free_arena(void *arena, size_t size)
 	munmap(arena, size);
 }
 
-static size_t grploc_count(struct ejit_func *f)
+/* value slots are mapped to physical registers such that the first
+ * available callee-save register is used. When those run out, skip R0/R1/R2 as
+ * they are reserved for transferring values to/from the stack, continue with
+ * caller-save. When those run out, put slots on stack.
+ *
+ * The reasoning here is that callee-save are stored to the stack once, and
+ * after that are 'free', whereas caller-save registers must be stored/restored
+ * before/after each call. With some more advanced liveness analysis we might be
+ * able to avoid storing most caller-save registers, but the simple scheme I
+ * have here just pushes all regs so try to avoid that as much as possible.
+ */
+static size_t physgpr_count(void)
 {
-	return f->gpr < jit_v_num() ? 0 : f->gpr - jit_v_num();
+	return jit_v_num() + jit_r_num() - 3;
 }
 
-static size_t frploc_count(struct ejit_func *f)
+/* how many gpr slots are on the stack */
+static size_t gprloc_stack_count(struct ejit_func *f)
 {
-	return f->fpr < jit_vf_num() ? 0 : f->fpr - jit_vf_num();
+	return gpr_stats_len(&f->gpr) < physgpr_count()
+		? 0
+		: gpr_stats_len(&f->gpr) - physgpr_count();
+}
+
+/* get physical register for ordered gpr slot */
+static jit_gpr_t physgpr_at(size_t r)
+{
+	if (r < jit_v_num())
+		return jit_v(r);
+
+	/* avoid R0 - R2 as they're reserved for tmp use */
+	return jit_r(r - jit_v_num() + 3);
+}
+
+static size_t caller_save_gprs(struct ejit_func *f)
+{
+	if (gpr_stats_len(&f->gpr) >= physgpr_count())
+		return jit_r_num() - 3;
+
+	if (gpr_stats_len(&f->gpr) <= jit_v_num())
+		return 0;
+
+	return gpr_stats_len(&f->gpr) - jit_v_num();
+}
+
+static size_t physfpr_count(void)
+{
+	return jit_vf_num() + jit_f_num() - 3;
+}
+
+static jit_fpr_t physfpr_at(size_t r)
+{
+	if (r < jit_vf_num())
+		return jit_vf(r);
+
+	return jit_f(r - jit_vf_num() + 3);
+}
+
+static size_t fprloc_stack_count(struct ejit_func *f)
+{
+	return fpr_stats_len(&f->fpr) < physfpr_count()
+		? 0
+		: fpr_stats_len(&f->fpr) - physfpr_count();
+}
+
+static size_t caller_save_fprs(struct ejit_func *f)
+{
+	if (fpr_stats_len(&f->fpr) >= physfpr_count())
+		return jit_f_num() - 3;
+
+	if (fpr_stats_len(&f->fpr) <= jit_vf_num())
+		return 0;
+
+	return fpr_stats_len(&f->fpr) - jit_vf_num();
 }
 
 static size_t stack_size(struct ejit_func *f)
 {
-	return grploc_count(f) * sizeof(jit_uword_t)
-	       + frploc_count(f) * sizeof(jit_float64_t);
+	return gprloc_stack_count(f) * sizeof(jit_uword_t)
+	       + fprloc_stack_count(f) * sizeof(jit_float64_t)
+	       + caller_save_gprs(f) * sizeof(jit_uword_t)
+	       + caller_save_fprs(f) * sizeof(jit_float64_t)
+	       ;
 }
 
 static jit_off_t stack_loc(size_t l)
 {
-	assert(l >= jit_v_num());
-	return (l - jit_v_num()) * sizeof(jit_uword_t);
+	assert(l >= physgpr_count());
+	return (l - physgpr_count()) * sizeof(jit_uword_t);
 }
 
 static jit_off_t stack_loc_f(struct ejit_func *f, size_t l)
 {
-	assert(l >= jit_vf_num());
-	return grploc_count(f) * sizeof(jit_uword_t)
-	       + (l - jit_vf_num()) * sizeof(jit_float64_t);
+	assert(l >= physfpr_count());
+	return gprloc_stack_count(f) * sizeof(jit_uword_t)
+	       + (l - physfpr_count()) * sizeof(jit_float64_t);
+}
+
+static jit_off_t stack_loc_save_gpr(struct ejit_func *f, size_t i)
+{
+	return gprloc_stack_count(f) * sizeof(jit_uword_t)
+		+ fprloc_stack_count(f) * sizeof(jit_float64_t)
+		+ i * sizeof(jit_uword_t);
+}
+
+static jit_off_t stack_loc_save_fpr(struct ejit_func *f, size_t i)
+{
+	return gprloc_stack_count(f) * sizeof(jit_uword_t)
+		+ fprloc_stack_count(f) * sizeof(jit_float64_t)
+		+ caller_save_gprs(f) * sizeof(jit_uword_t)
+		+ i * sizeof(jit_float64_t);
+}
+
+/* for now, just push all caller-save registers. Theoretically, we could fairly
+ * easily keep track of ranges where registers are alive and skip ones that are
+ * dead here, but I'm not sure how useful that would be without some form of
+ * SSA, which is maybe pushing how complex I want this to become. */
+static void save_caller_save_regs(struct ejit_func *f, jit_state_t *j)
+{
+	for (size_t i = 0; i < caller_save_gprs(f); ++i)
+		jit_stxi(j, stack_loc_save_gpr(f, i), JIT_SP, jit_r(i + 3));
+
+	for (size_t i = 0; i < caller_save_fprs(f); ++i)
+		jit_stxi_d(j, stack_loc_save_fpr(f, i), JIT_SP, jit_f(i + 3));
 }
 
+static void restore_caller_save_regs(struct ejit_func *f, jit_state_t *j)
+{
+	for (size_t i = 0; i < caller_save_gprs(f); ++i)
+		jit_ldxi(j, jit_r(i + 3), JIT_SP, stack_loc_save_gpr(f, i));
+
+	for (size_t i = 0; i < caller_save_fprs(f); ++i)
+		jit_ldxi_d(j, jit_f(i + 3), JIT_SP, stack_loc_save_fpr(f, i));
+}
 
+
+/* get ordered slot register. If slot is directly mapped to a physical register,
+ * return it, otherwise load from stack into R0-R2 (given by i) */
 static jit_gpr_t getloc(struct ejit_func *f, jit_state_t *j, size_t l, size_t i)
 {
-	assert(l < f->gpr);
-	if (l < jit_v_num())
-		return jit_v(l);
+	assert(l < gpr_stats_len(&f->gpr));
+	assert(i <= 2);
+	size_t r = gpr_stats_at(&f->gpr, l)->rno;
+	if (r < physgpr_count())
+		return physgpr_at(r);
 
-	jit_ldxi(j, jit_r(i), JIT_SP, stack_loc(l));
+	jit_ldxi(j, jit_r(i), JIT_SP, stack_loc(r));
 	return jit_r(i);
 }
 
 static jit_fpr_t getloc_f(struct ejit_func *f, jit_state_t *j, size_t l,
                           size_t i)
 {
-	assert(l < f->fpr);
-	if (l < jit_vf_num())
-		return jit_vf(l);
+	assert(l < fpr_stats_len(&f->fpr));
+	assert(i <= 2);
+	size_t r = fpr_stats_at(&f->fpr, l)->fno;
+	if (r < physfpr_count())
+		return physfpr_at(r);
 
-	jit_ldxi_f(j, jit_f(i), JIT_SP, stack_loc_f(f, l));
+	jit_ldxi_f(j, jit_f(i), JIT_SP, stack_loc_f(f, r));
 	return jit_f(i);
 }
 
 static jit_fpr_t getloc_d(struct ejit_func *f, jit_state_t *j, size_t l,
                           size_t i)
 {
-	assert(l < f->fpr);
-	if (l < jit_vf_num())
-		return jit_vf(l);
+	assert(l < fpr_stats_len(&f->fpr));
+	assert(i <= 2);
+	size_t r = fpr_stats_at(&f->fpr, l)->fno;
+	if (r < physfpr_count())
+		return physfpr_at(r);
 
 	/* not that stack_loc_f assumes double, so floats technically take up
 	 * more space than needed but at least we don't get any alignment issues */
-	jit_ldxi_d(j, jit_f(i), JIT_SP, stack_loc_f(f, l));
+	jit_ldxi_d(j, jit_f(i), JIT_SP, stack_loc_f(f, r));
 	return jit_f(i);
 }
 
+/* get physical register for slot l. If l is already in a physical register,
+ * return it, otherwise R0-R2 given by i. Does not fetch any values from a
+ * stack, mainly used for preparing a destination register. */
 static jit_gpr_t getgpr(struct ejit_func *f, size_t l, size_t i)
 {
-	assert(l < f->gpr);
-	if (l < jit_v_num())
-		return jit_v(l);
+	assert(l < gpr_stats_len(&f->gpr));
+	assert(i <= 2);
+	size_t r = gpr_stats_at(&f->gpr, l)->rno;
+	if (r < physgpr_count())
+		return physgpr_at(r);
 
 	return jit_r(i);
 }
 
 static jit_fpr_t getfpr(struct ejit_func *f, size_t l, size_t i)
 {
-	assert(l < f->fpr);
-	if (l < jit_vf_num())
-		return jit_vf(l);
+	assert(l < fpr_stats_len(&f->fpr));
+	assert(i <= 2);
+	size_t r = fpr_stats_at(&f->fpr, l)->fno;
+	if (r < physfpr_count())
+		return physfpr_at(r);
 
 	return jit_f(i);
 }
 
 static void putloc(struct ejit_func *f, jit_state_t *j, size_t l, jit_gpr_t r)
 {
-	assert(l < f->gpr);
-	if (l < jit_v_num()) {
-		assert(jit_v(l).regno == r.regno);
+	assert(l < gpr_stats_len(&f->gpr));
+	size_t rno = gpr_stats_at(&f->gpr, l)->rno;
+	if (rno < physgpr_count()) {
+		assert(physgpr_at(rno).regno == r.regno);
 		return;
 	}
 
-	jit_stxi(j, stack_loc(l), JIT_SP, r);
+	jit_stxi(j, stack_loc(rno), JIT_SP, r);
 }
 
 static void putloc_f(struct ejit_func *f, jit_state_t *j, size_t l, jit_fpr_t r)
 {
-	assert(l < f->fpr);
-	if (l < jit_vf_num()) {
-		assert(jit_v(l).regno == r.regno);
+	assert(l < fpr_stats_len(&f->fpr));
+	size_t rno = fpr_stats_at(&f->fpr, l)->fno;
+	if (rno < physfpr_count()) {
+		assert(physfpr_at(rno).regno == r.regno);
 		return;
 	}
 
-	jit_stxi_f(j, stack_loc_f(f, l), JIT_SP, r);
+	jit_stxi_f(j, stack_loc_f(f, rno), JIT_SP, r);
 }
 
 static void putloc_d(struct ejit_func *f, jit_state_t *j, size_t l, jit_fpr_t r)
 {
-	assert(l < f->fpr);
-	if (l < jit_vf_num()) {
-		assert(jit_v(l).regno == r.regno);
+	assert(l < fpr_stats_len(&f->fpr));
+	size_t rno = fpr_stats_at(&f->fpr, l)->fno;
+	if (rno < physfpr_count()) {
+		assert(physfpr_at(rno).regno == r.regno);
 		return;
 	}
 
-	jit_stxi_d(j, stack_loc_f(f, l), JIT_SP, r);
+	jit_stxi_d(j, stack_loc_f(f, rno), JIT_SP, r);
 }
 
 static void compile_label(jit_state_t *j, size_t ii, struct addrs *addrs)
@@ -1671,8 +1795,14 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 	jit_begin(j, arena, size);
 	compile_trampoline(f, j);
 
-	size_t gprs = f->gpr >= jit_v_num() ? jit_v_num() : f->gpr;
-	size_t fprs = f->fpr >= jit_vf_num() ? jit_vf_num() : f->fpr;
+	size_t gprs = gpr_stats_len(&f->gpr) >= jit_v_num()
+		? jit_v_num()
+		: gpr_stats_len(&f->gpr);
+
+	size_t fprs = fpr_stats_len(&f->fpr) >= jit_vf_num()
+		? jit_vf_num()
+		: fpr_stats_len(&f->fpr);
+
 	size_t frame = jit_enter_jit_abi(j, gprs, fprs, 0);
 	size_t stack = jit_align_stack(j, stack_size(f));
 
@@ -1876,19 +2006,20 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		case JMP: compile_jmp(f, j, i, &relocs); break;
 
 		case ARG: {
-			jit_operand_t type =
-				jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1);
+			size_t r2 = gpr_stats_at(&f->gpr, i.r2)->rno;
+			jit_operand_t type = jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1);
+
 			jit_operand_t arg;
-			if (i.r2 < jit_v_num()) {
+			if (r2 < physgpr_count()) {
 				/* regular register */
 				arg = jit_operand_gpr(jit_abi_from(i.r1),
-				                      jit_v(i.r2));
+				                      physgpr_at(r2));
 			}
 			else {
 				/* stack location, note that we'll fix up the SP
 				 * offset before doing the actual call */
 				arg = jit_operand_mem(jit_abi_from(i.r1),
-				                      JIT_SP, stack_loc(i.r2));
+				                      JIT_SP, stack_loc(r2));
 			}
 
 			operands_append(&src, type);
@@ -1908,20 +2039,21 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case ARG_F: {
-			jit_operand_t type =
-				jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1);
+			size_t f2 = fpr_stats_at(&f->fpr, i.r2)->fno;
+			jit_operand_t type = jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1);
+
 			jit_operand_t arg;
-			if (i.r2 < jit_vf_num()) {
+			if (f2 < physfpr_count()) {
 				/* regular register */
 				arg = jit_operand_fpr(jit_abi_from(i.r1),
-				                      jit_vf(i.r2));
+				                      physfpr_at(f2));
 			}
 			else {
 				/* stack location, note that we'll fix up the SP
 				 * offset before doing the actual call */
 				arg = jit_operand_mem(jit_abi_from(i.r1),
 				                      JIT_SP,
-				                      stack_loc_f(f, i.r2));
+				                      stack_loc_f(f, f2));
 			}
 
 			operands_append(&src, type);
@@ -1941,12 +2073,16 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case ESCAPEI: {
+			save_caller_save_regs(f, j);
+
 			jit_operand_t args[2] = {
 				jit_operand_imm(JIT_OPERAND_ABI_WORD,
 				                operands_len(&src) / 2),
 				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP)
 			};
 			compile_imm_call(j, &src, &dst, (void *)i.o, 2, args);
+			restore_caller_save_regs(f, j);
+
 			operands_reset(&src);
 			operands_reset(&dst);
 			operands_reset(&direct);
@@ -1954,12 +2090,16 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case ESCAPEI_F: {
+			save_caller_save_regs(f, j);
+
 			jit_operand_t args[2] = {
 				jit_operand_imm(JIT_OPERAND_ABI_WORD,
 				                operands_len(&src) / 2),
 				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP)
 			};
 			compile_imm_call(j, &src, &dst, (void *)i.o, 2, args);
+			restore_caller_save_regs(f, j);
+
 			operands_reset(&src);
 			operands_reset(&dst);
 			operands_reset(&direct);
@@ -1967,9 +2107,13 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case CALLI: {
+			struct ejit_func *caller = f; /* i.o shadows f below */
+			save_caller_save_regs(caller, j);
 			struct ejit_func *f = (struct ejit_func *)i.o;
 			if (f && f->direct_call) {
 				jit_calli(j, f->direct_call, operands_len(&direct), direct.buf);
+				restore_caller_save_regs(caller, j);
+
 				operands_reset(&src);
 				operands_reset(&dst);
 				operands_reset(&direct);
@@ -1983,6 +2127,8 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 				jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP)
 			};
 			compile_imm_call(j, &src, &dst, ejit_run_func, 3, args);
+			restore_caller_save_regs(caller, j);
+
 			operands_reset(&src);
 			operands_reset(&dst);
 			operands_reset(&direct);
@@ -2038,16 +2184,18 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case PARAM_F: {
+			size_t f2 = fpr_stats_at(&f->fpr, i.r2)->fno;
+
 			jit_operand_t to;
-			if (i.r2 < jit_vf_num()) {
+			if (f2 < physfpr_count()) {
 				/* regular register */
 				to = jit_operand_fpr(jit_abi_from(i.r1),
-				                     jit_vf(i.r2));
+				                     physfpr_at(f2));
 			}
 			else {
 				/* stack location */
 				to = jit_operand_mem(jit_abi_from(i.r1), JIT_SP,
-				                     stack_loc_f(f, i.r2));
+				                     stack_loc_f(f, f2));
 			}
 
 			operands_append(&dst, to);
@@ -2055,16 +2203,18 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 		}
 
 		case PARAM: {
+			size_t r2 = gpr_stats_at(&f->gpr, i.r2)->rno;
+
 			jit_operand_t to;
-			if (i.r2 < jit_v_num()) {
+			if (r2 < physgpr_count()) {
 				/* regular register */
 				to = jit_operand_gpr(jit_abi_from(i.r1),
-				                     jit_v(i.r2));
+				                     physgpr_at(r2));
 			}
 			else {
 				/* stack location */
 				to = jit_operand_mem(jit_abi_from(i.r1), JIT_SP,
-				                     stack_loc(i.r2));
+				                     stack_loc(r2));
 			}
 
 			operands_append(&dst, to);
@@ -2104,6 +2254,45 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
 	return size;
 }
 
+/* highest prio first */
+static int gpr_sort_prio(struct gpr_stat *a, struct gpr_stat *b)
+{
+	return (int)b->prio - (int)a->prio;
+}
+
+static int fpr_sort_prio(struct fpr_stat *a, struct fpr_stat *b)
+{
+	return (int)b->prio - (int)a->prio;
+}
+
+/* sort registers by highest priority first, then renumber registers in the
+ * given order. Higher priority is given a physical register first.
+ *
+ * Note that the `->r` field becomes 'meaningless' after sorting, and you should
+ * only use the `->rno` field after this point. Essentially, if you have a
+ * register EJIT_GPR(2), you should use `gpr_stats_at(2)->rno` for the 'actual'
+ * register number in `getloc` and the like.
+ *
+ * Can be a bit confusing, but this way we don't have to allocate any new
+ * arrays, which is cool. */
+static void assign_gprs(struct ejit_func *f)
+{
+	gpr_stats_sort(&f->gpr, (vec_comp_t)gpr_sort_prio);
+	for (size_t i = 0; i < gpr_stats_len(&f->gpr); ++i) {
+		size_t rno = gpr_stats_at(&f->gpr, i)->r.r;
+		gpr_stats_at(&f->gpr, rno)->rno = i;
+	}
+}
+
+static void assign_fprs(struct ejit_func *f)
+{
+	fpr_stats_sort(&f->fpr, (vec_comp_t)fpr_sort_prio);
+	for (size_t i = 0; i < fpr_stats_len(&f->fpr); ++i) {
+		size_t rno = fpr_stats_at(&f->fpr, i)->f.f;
+		fpr_stats_at(&f->fpr, rno)->fno = i;
+	}
+}
+
 bool ejit_compile(struct ejit_func *f, bool use_64)
 {
 	(void)use_64;
@@ -2115,6 +2304,9 @@ bool ejit_compile(struct ejit_func *f, bool use_64)
 	if (!init_jit())
 		return false;
 
+	assign_gprs(f);
+	assign_fprs(f);
+
 	/* the main overhead of compilation seems to be the syscall to mmap a
 	 * new arena, I might look into allocating a big buffer at once and
 	 * caching it to be reused later, might allow us to compile many small
-- 
cgit v1.2.3