aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKimplul <kimi.h.kuparinen@gmail.com>2025-05-17 13:42:33 +0300
committerKimplul <kimi.h.kuparinen@gmail.com>2025-05-17 13:42:33 +0300
commitb9372c7be73a7cad6d741f5323dc8b2b758198d4 (patch)
tree66b766bccad93abd2311b9215651bd94456b3728
parent5c11915931ad43617ba6f62189bb11630b9624d4 (diff)
downloadejit-b9372c7be73a7cad6d741f5323dc8b2b758198d4.tar.gz
ejit-b9372c7be73a7cad6d741f5323dc8b2b758198d4.zip
take loops into account in register allocatorHEADmaster
-rw-r--r--examples/matrix_mult.c35
-rw-r--r--src/common.h9
-rw-r--r--src/compile/compile.c70
-rw-r--r--src/ejit.c8
-rw-r--r--tests/z_matrix_mult.c143
5 files changed, 248 insertions, 17 deletions
diff --git a/examples/matrix_mult.c b/examples/matrix_mult.c
index fab2319..ff8af55 100644
--- a/examples/matrix_mult.c
+++ b/examples/matrix_mult.c
@@ -1,14 +1,14 @@
#include <stdio.h>
#include "../include/ejit/ejit.h"
-#define X 400
+#define X 10
int A[X][X];
int B[X][X];
int C[X][X];
static void init_matrices(int A[X][X], int B[X][X])
{
- int counter = 0;
+ int counter = 1;
for (size_t i = 0; i < X; ++i)
for (size_t j = 0; j < X; ++j) {
A[i][j] = counter;
@@ -41,7 +41,8 @@ static struct ejit_func *compile()
#define AX EJIT_GPR(6)
#define BX EJIT_GPR(7)
#define CX EJIT_GPR(8)
-#define TMP EJIT_GPR(9)
+#define RR EJIT_GPR(9)
+#define TMP EJIT_GPR(10)
struct ejit_operand args[3] = {
EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)),
@@ -58,6 +59,8 @@ static struct ejit_func *compile()
struct ejit_label mid = ejit_label(f);
struct ejit_reloc midder = ejit_bgei(f, JR, X);
+ ejit_movi(f, RR, 0);
+
ejit_movi(f, KR, 0);
struct ejit_label in = ejit_label(f);
struct ejit_reloc inner = ejit_bgei(f, KR, X);
@@ -78,20 +81,9 @@ static struct ejit_func *compile()
ejit_lshi(f, BX, BX, 2);
EJIT_LDXR(f, int, BX, BR, BX);
- /* C[i][j] at addr 4 * (i * 400 + j) */
- ejit_movi(f, TMP, X);
- ejit_mulr(f, CX, IR, TMP);
- ejit_addr(f, CX, CX, JR);
- ejit_movi(f, TMP, 4);
- /* reuse address */
- ejit_lshi(f, TMP, CX, 2);
- EJIT_LDXR(f, int, CX, CR, TMP);
-
- ejit_mulr(f, AX, AX, BX);
- ejit_addr(f, CX, CX, AX);
-
- /* store result */
- EJIT_STXR(f, int, CX, CR, TMP);
+ /* R += A[i][k] * B[k][j] */
+ ejit_mulr(f, TMP, AX, BX);
+ ejit_addr(f, RR, RR, TMP);
/* increment inner */
ejit_addi(f, KR, KR, 1);
@@ -99,6 +91,15 @@ static struct ejit_func *compile()
/* end of inner */
ejit_patch(f, inner, ejit_label(f));
+ /* C[i][j] at addr 4 * (i * 400 + j) */
+ ejit_movi(f, TMP, X);
+ ejit_mulr(f, CX, IR, TMP);
+ ejit_addr(f, CX, CX, JR);
+ ejit_lshi(f, CX, CX, 2);
+
+ /* C[i][j] = R */
+ EJIT_STXR(f, int, RR, CX, CR);
+
/* increment midder */
ejit_addi(f, JR, JR, 1);
ejit_patch(f, ejit_jmp(f), mid);
diff --git a/src/common.h b/src/common.h
index 1eb762e..41a17cf 100644
--- a/src/common.h
+++ b/src/common.h
@@ -4,6 +4,14 @@
#include <ejit/ejit.h>
#include <stdbool.h>
+struct barrier_tuple {
+ size_t start, end;
+};
+
+#define VEC_TYPE struct barrier_tuple
+#define VEC_NAME barriers
+#include <conts/vec.h>
+
#define VEC_TYPE struct ejit_arg
#define VEC_NAME args
#include <conts/vec.h>
@@ -291,6 +299,7 @@ struct ejit_func {
struct types sign;
struct insns insns;
struct labels labels;
+ struct barriers barriers;
enum ejit_type rtype;
struct gpr_stats gpr;
diff --git a/src/compile/compile.c b/src/compile/compile.c
index f552ae3..8e0e250 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -2856,6 +2856,18 @@ static int spill_cost_sort(struct alive_slot *a, struct alive_slot *b)
return a->cost < b->cost;
}
+/* sort barriers according to starting address, smallest to largest */
+static int barrier_sort(struct barrier_tuple *a, struct barrier_tuple *b)
+{
+ if (a->start < b->start)
+ return -1;
+
+ if (a->start > b->start)
+ return 1;
+
+ return a->end - b->end;
+}
+
/* slightly more parameters than I would like but I guess it's fine */
static void calculate_alive(struct alive *alive, size_t idx,
size_t prio, size_t start, size_t end, size_t *rno,
@@ -2924,6 +2936,40 @@ static void linear_gpr_alloc(struct ejit_func *f)
}
}
+static void extend_gpr_lifetime(struct gpr_stat *gpr, struct barriers *barriers, size_t bi)
+{
+ if (bi >= barriers_len(barriers))
+ return;
+
+ struct barrier_tuple *barrier = barriers_at(barriers, bi);
+ if (gpr->end < barrier->start)
+ return;
+
+ /* we cross a barrier */
+ gpr->end = gpr->end > barrier->end ? gpr->end : barrier->end;
+
+ /* check if we cross the next barrier as well,
+ * hopefully gets optimized into a tail call */
+ return extend_gpr_lifetime(gpr, barriers, bi + 1);
+}
+
+static void extend_fpr_lifetime(struct fpr_stat *fpr, struct barriers *barriers, size_t bi)
+{
+ if (bi >= barriers_len(barriers))
+ return;
+
+ struct barrier_tuple *barrier = barriers_at(barriers, bi);
+ if (fpr->end < barrier->start)
+ return;
+
+ /* we cross a barrier */
+ fpr->end = fpr->end > barrier->end ? fpr->end : barrier->end;
+
+ /* check if we cross the next barrier as well,
+ * hopefully gets optimized into a tail call */
+ return extend_fpr_lifetime(fpr, barriers, bi + 1);
+}
+
/* there's a fair bit of repetition between this and the gpr case, hmm */
static void assign_gprs(struct ejit_func *f)
{
@@ -2937,8 +2983,20 @@ static void assign_gprs(struct ejit_func *f)
struct alive_slot a = {.r = -1, .cost = 0, .idx = 0};
alive_append(&alive, a);
+ /* barrier index, keeps track of earliest possible barrier we can be
+ * dealing with. Since register start addresses grow upward, we can
+ * fairly easily keep track of which barrier a register cannot cross */
+ size_t bi = 0;
for (size_t gi = 0; gi < gpr_stats_len(&f->gpr); ++gi) {
struct gpr_stat *gpr = gpr_stats_at(&f->gpr, gi);
+
+ extend_gpr_lifetime(gpr, &f->barriers, bi);
+ if (bi < barriers_len(&f->barriers)) {
+ struct barrier_tuple *barrier = barriers_at(&f->barriers, bi);
+ if (gpr->start >= barrier->start)
+ bi++;
+ }
+
calculate_alive(&alive, gi,
gpr->prio, gpr->start, gpr->end, &gpr->rno,
&f->gpr, gpr_dead);
@@ -2989,8 +3047,17 @@ static void assign_fprs(struct ejit_func *f)
struct alive_slot a = {.r = -1, .cost = 0, .idx = 0};
alive_append(&alive, a);
+ size_t bi = 0;
for (size_t fi = 0; fi < fpr_stats_len(&f->fpr); ++fi) {
struct fpr_stat *fpr = fpr_stats_at(&f->fpr, fi);
+
+ extend_fpr_lifetime(fpr, &f->barriers, bi);
+ if (bi < barriers_len(&f->barriers)) {
+ struct barrier_tuple *barrier = barriers_at(&f->barriers, bi);
+ if (fpr->start >= barrier->start)
+ bi++;
+ }
+
calculate_alive(&alive, fi,
fpr->prio, fpr->start, fpr->end, &fpr->fno,
&f->fpr, fpr_dead);
@@ -3035,6 +3102,9 @@ bool ejit_compile(struct ejit_func *f, bool use_64, bool im_scawed)
if (!init_jit())
return false;
+ /* sort barriers so they can be used to extend register life times in
+ * loops */
+ barriers_sort(&f->barriers, barrier_sort);
assign_gprs(f);
assign_fprs(f);
diff --git a/src/ejit.c b/src/ejit.c
index 2341890..1bb272d 100644
--- a/src/ejit.c
+++ b/src/ejit.c
@@ -341,6 +341,7 @@ struct ejit_func *ejit_create_func(enum ejit_type rtype, size_t argc,
f->sign = types_create(0);
f->insns = insns_create(0);
f->labels = labels_create(0);
+ f->barriers = barriers_create(0);
f->gpr = gpr_stats_create(0);
f->fpr = fpr_stats_create(0);
f->arena = NULL;
@@ -434,6 +435,7 @@ void ejit_destroy_func(struct ejit_func *f)
types_destroy(&f->sign);
insns_destroy(&f->insns);
labels_destroy(&f->labels);
+ barriers_destroy(&f->barriers);
gpr_stats_destroy(&f->gpr);
fpr_stats_destroy(&f->fpr);
free(f);
@@ -452,6 +454,12 @@ void ejit_patch(struct ejit_func *f, struct ejit_reloc r, struct ejit_label l)
/** @todo some assert that checks the opcode? */
i.r0 = l.addr;
*insns_at(&f->insns, r.insn) = i;
+
+ struct barrier_tuple tuple = {
+ .start = r.insn > l.addr ? l.addr : r.insn,
+ .end = r.insn > l.addr ? r.insn : l.addr
+ };
+ barriers_append(&f->barriers, tuple);
}
void ejit_taili(struct ejit_func *s, struct ejit_func *f,
diff --git a/tests/z_matrix_mult.c b/tests/z_matrix_mult.c
new file mode 100644
index 0000000..8cfba69
--- /dev/null
+++ b/tests/z_matrix_mult.c
@@ -0,0 +1,143 @@
+#include <ejit/ejit.h>
+#include <assert.h>
+#include "do_jit.h"
+
+/* register allocator had a bug where registers whose last use was midway
+ * through a loop might get overwritten, here's one case that exposed the bug */
+
+#define X 10
+int A[X][X];
+int B[X][X];
+int C[X][X];
+
+bool do_jit = false;
+
+static void init_matrices(int A[X][X], int B[X][X])
+{
+ int counter = 1;
+ for (size_t i = 0; i < X; ++i)
+ for (size_t j = 0; j < X; ++j) {
+ A[i][j] = counter;
+ B[i][j] = counter;
+ C[i][j] = 0;
+
+ counter++;
+ }
+}
+
+static int hash(int C[X][X])
+{
+ int h = 0;
+ for (size_t i = 0; i < X; ++i)
+ for (size_t j = 0; j < X; ++j) {
+ h += C[i][j];
+ }
+
+ return h;
+}
+
+static struct ejit_func *compile()
+{
+#define IR EJIT_GPR(0)
+#define JR EJIT_GPR(1)
+#define KR EJIT_GPR(2)
+#define AR EJIT_GPR(3)
+#define BR EJIT_GPR(4)
+#define CR EJIT_GPR(5)
+#define AX EJIT_GPR(6)
+#define BX EJIT_GPR(7)
+#define CX EJIT_GPR(8)
+#define RR EJIT_GPR(9)
+#define TMP EJIT_GPR(10)
+
+ struct ejit_operand args[3] = {
+ EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), /* A */
+ EJIT_OPERAND_GPR(4, EJIT_TYPE(int *)), /* B */
+ EJIT_OPERAND_GPR(5, EJIT_TYPE(int *)), /* C */
+ };
+ struct ejit_func *f = ejit_create_func(EJIT_VOID, 3, args);
+
+ ejit_movi(f, IR, 0);
+ struct ejit_label out = ejit_label(f);
+ struct ejit_reloc outer = ejit_bgei(f, IR, X);
+
+ ejit_movi(f, JR, 0);
+ struct ejit_label mid = ejit_label(f);
+ struct ejit_reloc midder = ejit_bgei(f, JR, X);
+
+ ejit_movi(f, RR, 0);
+
+ ejit_movi(f, KR, 0);
+ struct ejit_label in = ejit_label(f);
+ struct ejit_reloc inner = ejit_bgei(f, KR, X);
+
+ /* A[i][k] at addr 4 * (i * 400 + k) */
+ ejit_movi(f, TMP, X);
+ ejit_mulr(f, AX, IR, TMP);
+ ejit_addr(f, AX, AX, KR);
+ ejit_movi(f, TMP, 4);
+ ejit_lshi(f, AX, AX, 2);
+ EJIT_LDXR(f, int, AX, AR, AX);
+
+ /* B[k][j] at addr 4 * (k * 400 + j) */
+ ejit_movi(f, TMP, X);
+ ejit_mulr(f, BX, KR, TMP);
+ ejit_addr(f, BX, BX, JR);
+ ejit_movi(f, TMP, 4);
+ ejit_lshi(f, BX, BX, 2);
+ EJIT_LDXR(f, int, BX, BR, BX);
+
+ /* R += A[i][k] * B[k][j] */
+ ejit_mulr(f, TMP, AX, BX);
+ ejit_addr(f, RR, RR, TMP);
+
+ /* increment inner */
+ ejit_addi(f, KR, KR, 1);
+ ejit_patch(f, ejit_jmp(f), in);
+ /* end of inner */
+ ejit_patch(f, inner, ejit_label(f));
+
+ /* C[i][j] at addr 4 * (i * 400 + j) */
+ ejit_movi(f, TMP, X);
+ ejit_mulr(f, CX, IR, TMP);
+ ejit_addr(f, CX, CX, JR);
+ ejit_lshi(f, CX, CX, 2);
+
+ /* C[i][j] = R */
+ EJIT_STXR(f, int, RR, CX, CR);
+
+ /* increment midder */
+ ejit_addi(f, JR, JR, 1);
+ ejit_patch(f, ejit_jmp(f), mid);
+ /* end of midder */
+ ejit_patch(f, midder, ejit_label(f));
+
+ /* increment outer */
+ ejit_addi(f, IR, IR, 1);
+ ejit_patch(f, ejit_jmp(f), out);
+ /* end of outer */
+ ejit_patch(f, outer, ejit_label(f));
+
+ /* return */
+ ejit_ret(f);
+ ejit_select_compile_func(f, 11, 0, EJIT_USE64(long), do_jit, true);
+ return f;
+}
+
+int main(int argc, char *argv[])
+{
+ (void)argv;
+ do_jit = argc > 1;
+
+ init_matrices(A, B);
+ struct ejit_func *f = compile();
+ struct ejit_arg args[3] = {
+ EJIT_ARG(A, int *),
+ EJIT_ARG(B, int *),
+ EJIT_ARG(C, int *)
+ };
+ ejit_run_func(f, 3, args);
+ assert(hash(C) == 2632750);
+ ejit_destroy_func(f);
+ return 0;
+}