diff options
-rw-r--r-- | examples/matrix_mult.c | 35 | ||||
-rw-r--r-- | src/common.h | 9 | ||||
-rw-r--r-- | src/compile/compile.c | 70 | ||||
-rw-r--r-- | src/ejit.c | 8 | ||||
-rw-r--r-- | tests/z_matrix_mult.c | 143 |
5 files changed, 248 insertions, 17 deletions
diff --git a/examples/matrix_mult.c b/examples/matrix_mult.c index fab2319..ff8af55 100644 --- a/examples/matrix_mult.c +++ b/examples/matrix_mult.c @@ -1,14 +1,14 @@ #include <stdio.h> #include "../include/ejit/ejit.h" -#define X 400 +#define X 10 int A[X][X]; int B[X][X]; int C[X][X]; static void init_matrices(int A[X][X], int B[X][X]) { - int counter = 0; + int counter = 1; for (size_t i = 0; i < X; ++i) for (size_t j = 0; j < X; ++j) { A[i][j] = counter; @@ -41,7 +41,8 @@ static struct ejit_func *compile() #define AX EJIT_GPR(6) #define BX EJIT_GPR(7) #define CX EJIT_GPR(8) -#define TMP EJIT_GPR(9) +#define RR EJIT_GPR(9) +#define TMP EJIT_GPR(10) struct ejit_operand args[3] = { EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), @@ -58,6 +59,8 @@ static struct ejit_func *compile() struct ejit_label mid = ejit_label(f); struct ejit_reloc midder = ejit_bgei(f, JR, X); + ejit_movi(f, RR, 0); + ejit_movi(f, KR, 0); struct ejit_label in = ejit_label(f); struct ejit_reloc inner = ejit_bgei(f, KR, X); @@ -78,20 +81,9 @@ static struct ejit_func *compile() ejit_lshi(f, BX, BX, 2); EJIT_LDXR(f, int, BX, BR, BX); - /* C[i][j] at addr 4 * (i * 400 + j) */ - ejit_movi(f, TMP, X); - ejit_mulr(f, CX, IR, TMP); - ejit_addr(f, CX, CX, JR); - ejit_movi(f, TMP, 4); - /* reuse address */ - ejit_lshi(f, TMP, CX, 2); - EJIT_LDXR(f, int, CX, CR, TMP); - - ejit_mulr(f, AX, AX, BX); - ejit_addr(f, CX, CX, AX); - - /* store result */ - EJIT_STXR(f, int, CX, CR, TMP); + /* R += A[i][k] * B[k][j] */ + ejit_mulr(f, TMP, AX, BX); + ejit_addr(f, RR, RR, TMP); /* increment inner */ ejit_addi(f, KR, KR, 1); @@ -99,6 +91,15 @@ static struct ejit_func *compile() /* end of inner */ ejit_patch(f, inner, ejit_label(f)); + /* C[i][j] at addr 4 * (i * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, CX, IR, TMP); + ejit_addr(f, CX, CX, JR); + ejit_lshi(f, CX, CX, 2); + + /* C[i][j] = R */ + EJIT_STXR(f, int, RR, CX, CR); + /* increment midder */ ejit_addi(f, JR, JR, 1); ejit_patch(f, ejit_jmp(f), mid); diff --git a/src/common.h b/src/common.h index 1eb762e..41a17cf 100644 --- a/src/common.h +++ b/src/common.h @@ -4,6 +4,14 @@ #include <ejit/ejit.h> #include <stdbool.h> +struct barrier_tuple { + size_t start, end; +}; + +#define VEC_TYPE struct barrier_tuple +#define VEC_NAME barriers +#include <conts/vec.h> + #define VEC_TYPE struct ejit_arg #define VEC_NAME args #include <conts/vec.h> @@ -291,6 +299,7 @@ struct ejit_func { struct types sign; struct insns insns; struct labels labels; + struct barriers barriers; enum ejit_type rtype; struct gpr_stats gpr; diff --git a/src/compile/compile.c b/src/compile/compile.c index f552ae3..8e0e250 100644 --- a/src/compile/compile.c +++ b/src/compile/compile.c @@ -2856,6 +2856,18 @@ static int spill_cost_sort(struct alive_slot *a, struct alive_slot *b) return a->cost < b->cost; } +/* sort barriers according to starting address, smallest to largest */ +static int barrier_sort(struct barrier_tuple *a, struct barrier_tuple *b) +{ + if (a->start < b->start) + return -1; + + if (a->start > b->start) + return 1; + + return a->end - b->end; +} + /* slightly more parameters than I would like but I guess it's fine */ static void calculate_alive(struct alive *alive, size_t idx, size_t prio, size_t start, size_t end, size_t *rno, @@ -2924,6 +2936,40 @@ static void linear_gpr_alloc(struct ejit_func *f) } } +static void extend_gpr_lifetime(struct gpr_stat *gpr, struct barriers *barriers, size_t bi) +{ + if (bi >= barriers_len(barriers)) + return; + + struct barrier_tuple *barrier = barriers_at(barriers, bi); + if (gpr->end < barrier->start) + return; + + /* we cross a barrier */ + gpr->end = gpr->end > barrier->end ? gpr->end : barrier->end; + + /* check if we cross the next barrier as well, + * hopefully gets optimized into a tail call */ + return extend_gpr_lifetime(gpr, barriers, bi + 1); +} + +static void extend_fpr_lifetime(struct fpr_stat *fpr, struct barriers *barriers, size_t bi) +{ + if (bi >= barriers_len(barriers)) + return; + + struct barrier_tuple *barrier = barriers_at(barriers, bi); + if (fpr->end < barrier->start) + return; + + /* we cross a barrier */ + fpr->end = fpr->end > barrier->end ? fpr->end : barrier->end; + + /* check if we cross the next barrier as well, + * hopefully gets optimized into a tail call */ + return extend_fpr_lifetime(fpr, barriers, bi + 1); +} + /* there's a fair bit of repetition between this and the gpr case, hmm */ static void assign_gprs(struct ejit_func *f) { @@ -2937,8 +2983,20 @@ static void assign_gprs(struct ejit_func *f) struct alive_slot a = {.r = -1, .cost = 0, .idx = 0}; alive_append(&alive, a); + /* barrier index, keeps track of earliest possible barrier we can be + * dealing with. Since register start addresses grow upward, we can + * fairly easily keep track of which barrier a register cannot cross */ + size_t bi = 0; for (size_t gi = 0; gi < gpr_stats_len(&f->gpr); ++gi) { struct gpr_stat *gpr = gpr_stats_at(&f->gpr, gi); + + extend_gpr_lifetime(gpr, &f->barriers, bi); + if (bi < barriers_len(&f->barriers)) { + struct barrier_tuple *barrier = barriers_at(&f->barriers, bi); + if (gpr->start >= barrier->start) + bi++; + } + calculate_alive(&alive, gi, gpr->prio, gpr->start, gpr->end, &gpr->rno, &f->gpr, gpr_dead); @@ -2989,8 +3047,17 @@ static void assign_fprs(struct ejit_func *f) struct alive_slot a = {.r = -1, .cost = 0, .idx = 0}; alive_append(&alive, a); + size_t bi = 0; for (size_t fi = 0; fi < fpr_stats_len(&f->fpr); ++fi) { struct fpr_stat *fpr = fpr_stats_at(&f->fpr, fi); + + extend_fpr_lifetime(fpr, &f->barriers, bi); + if (bi < barriers_len(&f->barriers)) { + struct barrier_tuple *barrier = barriers_at(&f->barriers, bi); + if (fpr->start >= barrier->start) + bi++; + } + calculate_alive(&alive, fi, fpr->prio, fpr->start, fpr->end, &fpr->fno, &f->fpr, fpr_dead); @@ -3035,6 +3102,9 @@ bool ejit_compile(struct ejit_func *f, bool use_64, bool im_scawed) if (!init_jit()) return false; + /* sort barriers so they can be used to extend register life times in + * loops */ + barriers_sort(&f->barriers, barrier_sort); assign_gprs(f); assign_fprs(f); @@ -341,6 +341,7 @@ struct ejit_func *ejit_create_func(enum ejit_type rtype, size_t argc, f->sign = types_create(0); f->insns = insns_create(0); f->labels = labels_create(0); + f->barriers = barriers_create(0); f->gpr = gpr_stats_create(0); f->fpr = fpr_stats_create(0); f->arena = NULL; @@ -434,6 +435,7 @@ void ejit_destroy_func(struct ejit_func *f) types_destroy(&f->sign); insns_destroy(&f->insns); labels_destroy(&f->labels); + barriers_destroy(&f->barriers); gpr_stats_destroy(&f->gpr); fpr_stats_destroy(&f->fpr); free(f); @@ -452,6 +454,12 @@ void ejit_patch(struct ejit_func *f, struct ejit_reloc r, struct ejit_label l) /** @todo some assert that checks the opcode? */ i.r0 = l.addr; *insns_at(&f->insns, r.insn) = i; + + struct barrier_tuple tuple = { + .start = r.insn > l.addr ? l.addr : r.insn, + .end = r.insn > l.addr ? r.insn : l.addr + }; + barriers_append(&f->barriers, tuple); } void ejit_taili(struct ejit_func *s, struct ejit_func *f, diff --git a/tests/z_matrix_mult.c b/tests/z_matrix_mult.c new file mode 100644 index 0000000..8cfba69 --- /dev/null +++ b/tests/z_matrix_mult.c @@ -0,0 +1,143 @@ +#include <ejit/ejit.h> +#include <assert.h> +#include "do_jit.h" + +/* register allocator had a bug where registers whose last use was midway + * through a loop might get overwritten, here's one case that exposed the bug */ + +#define X 10 +int A[X][X]; +int B[X][X]; +int C[X][X]; + +bool do_jit = false; + +static void init_matrices(int A[X][X], int B[X][X]) +{ + int counter = 1; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + A[i][j] = counter; + B[i][j] = counter; + C[i][j] = 0; + + counter++; + } +} + +static int hash(int C[X][X]) +{ + int h = 0; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + h += C[i][j]; + } + + return h; +} + +static struct ejit_func *compile() +{ +#define IR EJIT_GPR(0) +#define JR EJIT_GPR(1) +#define KR EJIT_GPR(2) +#define AR EJIT_GPR(3) +#define BR EJIT_GPR(4) +#define CR EJIT_GPR(5) +#define AX EJIT_GPR(6) +#define BX EJIT_GPR(7) +#define CX EJIT_GPR(8) +#define RR EJIT_GPR(9) +#define TMP EJIT_GPR(10) + + struct ejit_operand args[3] = { + EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), /* A */ + EJIT_OPERAND_GPR(4, EJIT_TYPE(int *)), /* B */ + EJIT_OPERAND_GPR(5, EJIT_TYPE(int *)), /* C */ + }; + struct ejit_func *f = ejit_create_func(EJIT_VOID, 3, args); + + ejit_movi(f, IR, 0); + struct ejit_label out = ejit_label(f); + struct ejit_reloc outer = ejit_bgei(f, IR, X); + + ejit_movi(f, JR, 0); + struct ejit_label mid = ejit_label(f); + struct ejit_reloc midder = ejit_bgei(f, JR, X); + + ejit_movi(f, RR, 0); + + ejit_movi(f, KR, 0); + struct ejit_label in = ejit_label(f); + struct ejit_reloc inner = ejit_bgei(f, KR, X); + + /* A[i][k] at addr 4 * (i * 400 + k) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, AX, IR, TMP); + ejit_addr(f, AX, AX, KR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, AX, AX, 2); + EJIT_LDXR(f, int, AX, AR, AX); + + /* B[k][j] at addr 4 * (k * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, BX, KR, TMP); + ejit_addr(f, BX, BX, JR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, BX, BX, 2); + EJIT_LDXR(f, int, BX, BR, BX); + + /* R += A[i][k] * B[k][j] */ + ejit_mulr(f, TMP, AX, BX); + ejit_addr(f, RR, RR, TMP); + + /* increment inner */ + ejit_addi(f, KR, KR, 1); + ejit_patch(f, ejit_jmp(f), in); + /* end of inner */ + ejit_patch(f, inner, ejit_label(f)); + + /* C[i][j] at addr 4 * (i * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, CX, IR, TMP); + ejit_addr(f, CX, CX, JR); + ejit_lshi(f, CX, CX, 2); + + /* C[i][j] = R */ + EJIT_STXR(f, int, RR, CX, CR); + + /* increment midder */ + ejit_addi(f, JR, JR, 1); + ejit_patch(f, ejit_jmp(f), mid); + /* end of midder */ + ejit_patch(f, midder, ejit_label(f)); + + /* increment outer */ + ejit_addi(f, IR, IR, 1); + ejit_patch(f, ejit_jmp(f), out); + /* end of outer */ + ejit_patch(f, outer, ejit_label(f)); + + /* return */ + ejit_ret(f); + ejit_select_compile_func(f, 11, 0, EJIT_USE64(long), do_jit, true); + return f; +} + +int main(int argc, char *argv[]) +{ + (void)argv; + do_jit = argc > 1; + + init_matrices(A, B); + struct ejit_func *f = compile(); + struct ejit_arg args[3] = { + EJIT_ARG(A, int *), + EJIT_ARG(B, int *), + EJIT_ARG(C, int *) + }; + ejit_run_func(f, 3, args); + assert(hash(C) == 2632750); + ejit_destroy_func(f); + return 0; +} |