diff options
-rw-r--r-- | .gitmodules | 3 | ||||
m--------- | deps/conts | 0 | ||||
-rw-r--r-- | examples/matrix_mult.c | 35 | ||||
-rw-r--r-- | include/ejit/ejit.h | 24 | ||||
-rw-r--r-- | scripts/makefile | 2 | ||||
-rw-r--r-- | src/common.h | 25 | ||||
-rw-r--r-- | src/compile/compile.c | 223 | ||||
-rw-r--r-- | src/ejit.c | 26 | ||||
-rw-r--r-- | src/vec.h | 119 | ||||
-rw-r--r-- | tests/z_matrix_mult.c | 143 |
10 files changed, 396 insertions, 204 deletions
diff --git a/.gitmodules b/.gitmodules index a514045..24f5f47 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "deps/lightening"] path = deps/lightening url = https://gitlab.com/Kimplul/lightening.git +[submodule "deps/conts"] + path = deps/conts + url = https://metanimi.dy.fi/cgit/conts diff --git a/deps/conts b/deps/conts new file mode 160000 +Subproject 5a4cb1b5a8ba258a23a62feab1e66cc7d0eba3b diff --git a/examples/matrix_mult.c b/examples/matrix_mult.c index fab2319..ff8af55 100644 --- a/examples/matrix_mult.c +++ b/examples/matrix_mult.c @@ -1,14 +1,14 @@ #include <stdio.h> #include "../include/ejit/ejit.h" -#define X 400 +#define X 10 int A[X][X]; int B[X][X]; int C[X][X]; static void init_matrices(int A[X][X], int B[X][X]) { - int counter = 0; + int counter = 1; for (size_t i = 0; i < X; ++i) for (size_t j = 0; j < X; ++j) { A[i][j] = counter; @@ -41,7 +41,8 @@ static struct ejit_func *compile() #define AX EJIT_GPR(6) #define BX EJIT_GPR(7) #define CX EJIT_GPR(8) -#define TMP EJIT_GPR(9) +#define RR EJIT_GPR(9) +#define TMP EJIT_GPR(10) struct ejit_operand args[3] = { EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), @@ -58,6 +59,8 @@ static struct ejit_func *compile() struct ejit_label mid = ejit_label(f); struct ejit_reloc midder = ejit_bgei(f, JR, X); + ejit_movi(f, RR, 0); + ejit_movi(f, KR, 0); struct ejit_label in = ejit_label(f); struct ejit_reloc inner = ejit_bgei(f, KR, X); @@ -78,20 +81,9 @@ static struct ejit_func *compile() ejit_lshi(f, BX, BX, 2); EJIT_LDXR(f, int, BX, BR, BX); - /* C[i][j] at addr 4 * (i * 400 + j) */ - ejit_movi(f, TMP, X); - ejit_mulr(f, CX, IR, TMP); - ejit_addr(f, CX, CX, JR); - ejit_movi(f, TMP, 4); - /* reuse address */ - ejit_lshi(f, TMP, CX, 2); - EJIT_LDXR(f, int, CX, CR, TMP); - - ejit_mulr(f, AX, AX, BX); - ejit_addr(f, CX, CX, AX); - - /* store result */ - EJIT_STXR(f, int, CX, CR, TMP); + /* R += A[i][k] * B[k][j] */ + ejit_mulr(f, TMP, AX, BX); + ejit_addr(f, RR, RR, TMP); /* increment inner */ ejit_addi(f, KR, KR, 1); @@ -99,6 +91,15 @@ static struct ejit_func *compile() /* end of inner */ ejit_patch(f, inner, ejit_label(f)); + /* C[i][j] at addr 4 * (i * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, CX, IR, TMP); + ejit_addr(f, CX, CX, JR); + ejit_lshi(f, CX, CX, 2); + + /* C[i][j] = R */ + EJIT_STXR(f, int, RR, CX, CR); + /* increment midder */ ejit_addi(f, JR, JR, 1); ejit_patch(f, ejit_jmp(f), mid); diff --git a/include/ejit/ejit.h b/include/ejit/ejit.h index d3bc6c3..9343110 100644 --- a/include/ejit/ejit.h +++ b/include/ejit/ejit.h @@ -564,6 +564,16 @@ static inline void ejit_ldxi_ptr(struct ejit_func *s, struct ejit_gpr r0, abort(); } +static inline void ejit_ldi_ptr(struct ejit_func *s, struct ejit_gpr r0, void *p) +{ + if (sizeof(void *) == sizeof(int64_t)) + ejit_ldi_u64(s, r0, p); + else if (sizeof(void *) == sizeof(int32_t)) + ejit_ldi_u32(s, r0, p); + else + abort(); +} + static inline void ejit_ldxi_label(struct ejit_func *s, struct ejit_gpr r0, struct ejit_gpr r1, int64_t o) { @@ -571,6 +581,20 @@ static inline void ejit_ldxi_label(struct ejit_func *s, struct ejit_gpr r0, ejit_ldxi_ptr(s, r0, r1, o); } +#define EJIT_LDI(f, t, r0, o) \ + _Generic((t)(0), \ + int8_t: ejit_ldi_i8, \ + uint8_t: ejit_ldi_u8, \ + int16_t: ejit_ldi_i16, \ + uint16_t: ejit_ldi_u16, \ + int32_t: ejit_ldi_i32, \ + uint32_t: ejit_ldi_u32, \ + int64_t: ejit_ldi_i64, \ + uint64_t: ejit_ldi_u64, \ + float: ejit_ldi_f, \ + double: ejit_ldi_d, \ + default: ejit_ldi_ptr)((f), (r0), (o)) + #define EJIT_LDXI(f, t, r0, r1, o) \ _Generic((t)(0), \ int8_t: ejit_ldxi_i8, \ diff --git a/scripts/makefile b/scripts/makefile index dbb7a1c..6e3b972 100644 --- a/scripts/makefile +++ b/scripts/makefile @@ -45,7 +45,7 @@ WARNFLAGS := -Wall -Wextra COMPILE_FLAGS := $(CFLAGS) $(WARNFLAGS) $(OPTFLAGS) $(LTOFLAGS) \ $(OBFLAGS) $(DEBUGFLAGS) -INCLUDE_FLAGS := -I include +INCLUDE_FLAGS := -I include -I deps/conts/include COMPILE = $(COMPILER) \ $(COMPILE_FLAGS) $(DEPFLAGS) $(INCLUDE_FLAGS) diff --git a/src/common.h b/src/common.h index dc970f0..41a17cf 100644 --- a/src/common.h +++ b/src/common.h @@ -4,21 +4,29 @@ #include <ejit/ejit.h> #include <stdbool.h> +struct barrier_tuple { + size_t start, end; +}; + +#define VEC_TYPE struct barrier_tuple +#define VEC_NAME barriers +#include <conts/vec.h> + #define VEC_TYPE struct ejit_arg #define VEC_NAME args -#include "vec.h" +#include <conts/vec.h> #define VEC_TYPE int64_t #define VEC_NAME gprs -#include "vec.h" +#include <conts/vec.h> #define VEC_TYPE double #define VEC_NAME fprs -#include "vec.h" +#include <conts/vec.h> #define VEC_TYPE size_t #define VEC_NAME labels -#include "vec.h" +#include <conts/vec.h> enum ejit_opcode { EJIT_OP_MOVI, @@ -263,11 +271,11 @@ struct ejit_insn { #define VEC_TYPE struct ejit_insn #define VEC_NAME insns -#include "vec.h" +#include <conts/vec.h> #define VEC_TYPE enum ejit_type #define VEC_NAME types -#include "vec.h" +#include <conts/vec.h> struct fpr_stat { struct ejit_fpr f; @@ -276,7 +284,7 @@ struct fpr_stat { #define VEC_NAME fpr_stats #define VEC_TYPE struct fpr_stat -#include "vec.h" +#include <conts/vec.h> struct gpr_stat { struct ejit_gpr r; @@ -285,12 +293,13 @@ struct gpr_stat { #define VEC_NAME gpr_stats #define VEC_TYPE struct gpr_stat -#include "vec.h" +#include <conts/vec.h> struct ejit_func { struct types sign; struct insns insns; struct labels labels; + struct barriers barriers; enum ejit_type rtype; struct gpr_stats gpr; diff --git a/src/compile/compile.c b/src/compile/compile.c index 5432bc1..2c72b91 100644 --- a/src/compile/compile.c +++ b/src/compile/compile.c @@ -4,7 +4,7 @@ #define VEC_TYPE jit_operand_t #define VEC_NAME operands -#include "../vec.h" +#include <conts/vec.h> struct reloc_helper { jit_reloc_t r; @@ -13,11 +13,11 @@ struct reloc_helper { #define VEC_TYPE struct reloc_helper #define VEC_NAME relocs -#include "../vec.h" +#include <conts/vec.h> #define VEC_TYPE jit_addr_t #define VEC_NAME addrs -#include "../vec.h" +#include <conts/vec.h> /* skip assertions since we know they must be valid due to type checking earlier */ static long checked_run_i(struct ejit_func *f, size_t argc, struct ejit_arg args[argc]) @@ -1880,13 +1880,11 @@ static jit_off_t type_offset(struct ejit_insn i) static void fixup_operands(struct operands *operands, size_t fixup) { - foreach_vec(i, *operands) { - jit_operand_t op = *operands_at(operands, i); - if (op.kind != JIT_OPERAND_KIND_MEM) + foreach(operands, op, operands) { + if (op->kind != JIT_OPERAND_KIND_MEM) continue; - op.loc.mem.offset += fixup; - *operands_at(operands, i) = op; + op->loc.mem.offset += fixup; } } @@ -1922,17 +1920,22 @@ static void compile_trampoline(struct ejit_func *f, jit_state_t *j) struct operands args = operands_create(0); - foreach_vec(ii, f->insns) { - struct ejit_insn i = *insns_at(&f->insns, ii); - switch (i.op) { + foreach(insns, i, &f->insns) { + switch (i->op) { case EJIT_OP_PARAM: { - jit_operand_t p = jit_operand_mem(jit_abi_from(i.r1), JIT_R1, arg_offset(i)); + jit_operand_t p = jit_operand_mem( + jit_abi_from(i->r1), + JIT_R1, arg_offset(*i)); + operands_append(&args, p); break; } case EJIT_OP_PARAM_F: { - jit_operand_t p = jit_operand_mem(jit_abi_from(i.r1), JIT_R1, arg_offset(i)); + jit_operand_t p = jit_operand_mem( + jit_abi_from(i->r1), + JIT_R1, arg_offset(*i)); + operands_append(&args, p); break; } @@ -1971,17 +1974,21 @@ static void resolve_top_reloc(jit_state_t *j, struct relocs *relocs, struct addr assert(a); jit_patch_there(j, r, a); relocs_pop(relocs); + + /* hope this turns into a tailcall */ + if (relocs_len(relocs)) + resolve_top_reloc(j, relocs, addrs, ii); } static void resolve_relocs(jit_state_t *j, struct relocs *relocs, struct addrs *addrs, size_t ii) { - foreach_vec(ri, *relocs) { - struct reloc_helper h = *relocs_at(relocs, ri); - if (h.to != ii) + for (size_t ri = 0; ri < relocs_len(relocs); ++ri) { + struct reloc_helper *h = relocs_at(relocs, ri); + if (h->to != ii) continue; jit_addr_t a = *addrs_at(addrs, ii); - jit_reloc_t r = h.r; + jit_reloc_t r = h->r; assert(a); jit_patch_there(j, r, a); @@ -2052,16 +2059,16 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, size_t frame = jit_enter_jit_abi(j, gprs, fprs, 0); size_t stack = jit_align_stack(j, stack_size(f)); - struct operands src = operands_create(); - struct operands dst = operands_create(); - struct operands direct = operands_create(); + struct operands src = operands_create(f->max_args); + struct operands dst = operands_create(f->max_args); + struct operands direct = operands_create(f->max_args); - struct relocs relocs = relocs_create(); - struct addrs addrs = addrs_create(); + struct relocs relocs = relocs_create(labels_len(&f->labels)); + struct addrs addrs = addrs_create(0); addrs_reserve(&addrs, insns_len(&f->insns)); size_t label = 0; - foreach_vec(ii, f->insns) { + for (size_t ii = 0; ii < insns_len(&f->insns); ++ii) { /* if we've hit a label, add it to our vector of label addresses */ if (label < labels_len(&f->labels)) { if (*labels_at(&f->labels, label) == ii) { @@ -2530,7 +2537,7 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, /* now move args into place */ jit_operand_t args[2] = {}; - foreach_vec(oi, direct) { + for (size_t oi = 0; oi < operands_len(&direct); ++oi) { args[oi] = *operands_at(&direct, oi); } @@ -2665,13 +2672,12 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, case EJIT_OP_CALLI: { save_caller_save_regs(f, j); - - struct ejit_func *f = (struct ejit_func *)i.p; + struct ejit_func *t = (struct ejit_func *)i.p; #if __WORDSIZE != 64 - assert(f->rtype != EJIT_INT64 && f->rtype != EJIT_UINT64); + assert(t->rtype != EJIT_INT64 && t->rtype != EJIT_UINT64); #endif - if (f && f->direct_call) { - jit_calli(j, f->direct_call, operands_len(&direct), direct.buf); + if (t && t->direct_call) { + jit_calli(j, t->direct_call, operands_len(&direct), direct.buf); restore_caller_save_regs(f, j); operands_reset(&src); @@ -2690,7 +2696,7 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, }; void *call = NULL; - switch (f->rtype) { + switch (t->rtype) { case EJIT_INT64: case EJIT_UINT64: call = checked_run_l; break; case EJIT_FLOAT: call = checked_run_f; break; @@ -2843,7 +2849,7 @@ struct alive_slot { #define VEC_NAME alive #define VEC_TYPE struct alive_slot -#include "../vec.h" +#include <conts/vec.h> static int spill_cost_sort(struct alive_slot *a, struct alive_slot *b) { @@ -2853,6 +2859,38 @@ static int spill_cost_sort(struct alive_slot *a, struct alive_slot *b) return a->cost < b->cost; } +/* sort barriers according to starting address, smallest to largest */ +static int barrier_sort(struct barrier_tuple *a, struct barrier_tuple *b) +{ + if (a->start < b->start) + return -1; + + if (a->start > b->start) + return 1; + + if (a->end < b->end) + return -1; + + return 1; +} + +/* sort gprs in order of starting address */ +static int gpr_start_sort(struct gpr_stat *a, struct gpr_stat *b) +{ + if (a->start < b->start) + return -1; + + return 1; +} + +static int fpr_start_sort(struct fpr_stat *a, struct fpr_stat *b) +{ + if (a->start < b->start) + return -1; + + return 1; +} + /* slightly more parameters than I would like but I guess it's fine */ static void calculate_alive(struct alive *alive, size_t idx, size_t prio, size_t start, size_t end, size_t *rno, @@ -2871,7 +2909,7 @@ static void calculate_alive(struct alive *alive, size_t idx, long max_cost_idx = -1; size_t max_cost = 0; long counter = 0; - foreach_vec(ai, *alive) { + for (size_t ai = 0; ai < alive_len(alive); ++ai) { /* skip oneshot */ if (ai == 0) goto next; @@ -2916,11 +2954,45 @@ static int gpr_dead(void *regs, size_t idx, size_t start) static void linear_gpr_alloc(struct ejit_func *f) { - foreach_vec(gi, f->gpr) { + for (size_t gi = 0; gi < gpr_stats_len(&f->gpr); ++gi) { gpr_stats_at(&f->gpr, gi)->rno = gi; } } +static void extend_gpr_lifetime(struct gpr_stat *gpr, struct barriers *barriers, size_t bi) +{ + if (bi >= barriers_len(barriers)) + return; + + struct barrier_tuple *barrier = barriers_at(barriers, bi); + if (gpr->end < barrier->start) + return; + + /* we cross a barrier */ + gpr->end = gpr->end > barrier->end ? gpr->end : barrier->end; + + /* check if we cross the next barrier as well, + * hopefully gets optimized into a tail call */ + return extend_gpr_lifetime(gpr, barriers, bi + 1); +} + +static void extend_fpr_lifetime(struct fpr_stat *fpr, struct barriers *barriers, size_t bi) +{ + if (bi >= barriers_len(barriers)) + return; + + struct barrier_tuple *barrier = barriers_at(barriers, bi); + if (fpr->end < barrier->start) + return; + + /* we cross a barrier */ + fpr->end = fpr->end > barrier->end ? fpr->end : barrier->end; + + /* check if we cross the next barrier as well, + * hopefully gets optimized into a tail call */ + return extend_fpr_lifetime(fpr, barriers, bi + 1); +} + /* there's a fair bit of repetition between this and the gpr case, hmm */ static void assign_gprs(struct ejit_func *f) { @@ -2928,22 +3000,44 @@ static void assign_gprs(struct ejit_func *f) if (gpr_stats_len(&f->gpr) <= physgpr_count()) return linear_gpr_alloc(f); - struct alive alive = alive_create(gpr_stats_len(&f->gpr)); + /* create temporary buffer to sort */ + struct gpr_stats gprs = gpr_stats_create(gpr_stats_len(&f->gpr)); + foreach(gpr_stats, g, &f->gpr) { + gpr_stats_append(&gprs, *g); + } + + gpr_stats_sort(&gprs, gpr_start_sort); + + struct alive alive = alive_create(gpr_stats_len(&gprs)); /* special oneshot register class */ struct alive_slot a = {.r = -1, .cost = 0, .idx = 0}; alive_append(&alive, a); - foreach_vec(gi, f->gpr) { - struct gpr_stat *gpr = gpr_stats_at(&f->gpr, gi); + /* barrier index, keeps track of earliest possible barrier we can be + * dealing with. Since register start addresses grow upward, we can + * fairly easily keep track of which barrier a register cannot cross */ + size_t bi = 0; + for (size_t gi = 0; gi < gpr_stats_len(&gprs); ++gi) { + struct gpr_stat *gpr = gpr_stats_at(&gprs, gi); + if (gpr->prio == 0) + continue; + + extend_gpr_lifetime(gpr, &f->barriers, bi); + if (bi < barriers_len(&f->barriers)) { + struct barrier_tuple *barrier = barriers_at(&f->barriers, bi); + if (gpr->start >= barrier->start) + bi++; + } + calculate_alive(&alive, gi, gpr->prio, gpr->start, gpr->end, &gpr->rno, - &f->gpr, gpr_dead); + &gprs, gpr_dead); } /* sort so that the highest spill cost register classes are at the front and * as such more likely to be placed in registers */ - alive_sort(&alive, (vec_comp_t)spill_cost_sort); + alive_sort(&alive, (alive_comp_t)spill_cost_sort); /* update remapping info */ for(size_t i = 0; i < alive_len(&alive); ++i) { @@ -2952,13 +3046,18 @@ static void assign_gprs(struct ejit_func *f) } /* remap locations */ - for (size_t i = 0; i < gpr_stats_len(&f->gpr); ++i) { - struct gpr_stat *gpr = gpr_stats_at(&f->gpr, i); + for (size_t i = 0; i < gpr_stats_len(&gprs); ++i) { + struct gpr_stat *gpr = gpr_stats_at(&gprs, i); + if (gpr->prio == 0) + continue; + struct alive_slot *a = alive_at(&alive, gpr->rno); - gpr->rno = a->remap; + struct gpr_stat *orig = gpr_stats_at(&f->gpr, gpr->r.r); + orig->rno = a->remap; } alive_destroy(&alive); + gpr_stats_destroy(&gprs); } static int fpr_dead(void *regs, size_t idx, size_t start) @@ -2969,7 +3068,7 @@ static int fpr_dead(void *regs, size_t idx, size_t start) static void linear_fpr_alloc(struct ejit_func *f) { - foreach_vec(fi, f->fpr) { + for (size_t fi = 0; fi < fpr_stats_len(&f->fpr); ++fi) { fpr_stats_at(&f->fpr, fi)->fno = fi; } } @@ -2980,22 +3079,40 @@ static void assign_fprs(struct ejit_func *f) if (fpr_stats_len(&f->fpr) <= physfpr_count()) return linear_fpr_alloc(f); + struct fpr_stats fprs = fpr_stats_create(fpr_stats_len(&f->fpr)); + foreach(fpr_stats, r, &f->fpr) { + fpr_stats_append(&fprs, *r); + } + + fpr_stats_sort(&fprs, fpr_start_sort); + struct alive alive = alive_create(fpr_stats_len(&f->fpr)); /* special oneshot register class */ struct alive_slot a = {.r = -1, .cost = 0, .idx = 0}; alive_append(&alive, a); - foreach_vec(fi, f->fpr) { - struct fpr_stat *fpr = fpr_stats_at(&f->fpr, fi); + size_t bi = 0; + for (size_t fi = 0; fi < fpr_stats_len(&fprs); ++fi) { + struct fpr_stat *fpr = fpr_stats_at(&fprs, fi); + if (fpr->prio == 0) + continue; + + extend_fpr_lifetime(fpr, &f->barriers, bi); + if (bi < barriers_len(&f->barriers)) { + struct barrier_tuple *barrier = barriers_at(&f->barriers, bi); + if (fpr->start >= barrier->start) + bi++; + } + calculate_alive(&alive, fi, fpr->prio, fpr->start, fpr->end, &fpr->fno, - &f->fpr, fpr_dead); + &fprs, fpr_dead); } /* sort so that the highest spill cost register classes are at the front and * as such more likely to be placed in registers */ - alive_sort(&alive, (vec_comp_t)spill_cost_sort); + alive_sort(&alive, (alive_comp_t)spill_cost_sort); /* update remapping info */ for(size_t i = 0; i < alive_len(&alive); ++i) { @@ -3004,13 +3121,18 @@ static void assign_fprs(struct ejit_func *f) } /* remap locations */ - for (size_t i = 0; i < fpr_stats_len(&f->fpr); ++i) { - struct fpr_stat *fpr = fpr_stats_at(&f->fpr, i); + for (size_t i = 0; i < fpr_stats_len(&fprs); ++i) { + struct fpr_stat *fpr = fpr_stats_at(&fprs, i); + if (fpr->prio == 0) + continue; + struct alive_slot *a = alive_at(&alive, fpr->fno); - fpr->fno = a->remap; + struct fpr_stat *orig = fpr_stats_at(&f->fpr, fpr->f.f); + orig->fno = a->remap; } alive_destroy(&alive); + fpr_stats_destroy(&fprs); } static size_t align_up(size_t a, size_t n) @@ -3032,6 +3154,9 @@ bool ejit_compile(struct ejit_func *f, bool use_64, bool im_scawed) if (!init_jit()) return false; + /* sort barriers so they can be used to extend register lifetimes in + * loops */ + barriers_sort(&f->barriers, barrier_sort); assign_gprs(f); assign_fprs(f); @@ -338,11 +338,12 @@ struct ejit_func *ejit_create_func(enum ejit_type rtype, size_t argc, f->rtype = rtype; - f->sign = types_create(); - f->insns = insns_create(); - f->labels = labels_create(); - f->gpr = gpr_stats_create(); - f->fpr = fpr_stats_create(); + f->sign = types_create(0); + f->insns = insns_create(0); + f->labels = labels_create(0); + f->barriers = barriers_create(0); + f->gpr = gpr_stats_create(0); + f->fpr = fpr_stats_create(0); f->arena = NULL; f->direct_call = NULL; f->extern_call = NULL; @@ -416,12 +417,10 @@ void ejit_select_compile_func(struct ejit_func *f, size_t gpr, size_t fpr, /* just get labels, don't actually run anything yet */ ejit_run(f, 0, NULL, &labels); - foreach_vec(ii, f->insns) { - struct ejit_insn i = *insns_at(&f->insns, ii); - void *addr = labels[i.op]; + foreach(insns, i, &f->insns) { + void *addr = labels[i->op]; assert(addr); - i.addr = addr; - *insns_at(&f->insns, ii) = i; + i->addr = addr; } /* doesn't really matter what we put here as long as it isn't 0 */ @@ -436,6 +435,7 @@ void ejit_destroy_func(struct ejit_func *f) types_destroy(&f->sign); insns_destroy(&f->insns); labels_destroy(&f->labels); + barriers_destroy(&f->barriers); gpr_stats_destroy(&f->gpr); fpr_stats_destroy(&f->fpr); free(f); @@ -454,6 +454,12 @@ void ejit_patch(struct ejit_func *f, struct ejit_reloc r, struct ejit_label l) /** @todo some assert that checks the opcode? */ i.r0 = l.addr; *insns_at(&f->insns, r.insn) = i; + + struct barrier_tuple tuple = { + .start = r.insn > l.addr ? l.addr : r.insn, + .end = r.insn > l.addr ? r.insn : l.addr + }; + barriers_append(&f->barriers, tuple); } void ejit_taili(struct ejit_func *s, struct ejit_func *f, diff --git a/src/vec.h b/src/vec.h deleted file mode 100644 index f5a6fd9..0000000 --- a/src/vec.h +++ /dev/null @@ -1,119 +0,0 @@ -#ifndef VEC_TYPE -#error "Need vector type" -#endif - -#ifndef VEC_NAME -#error "Need vector name" -#endif - -#include <stddef.h> -#include <stdlib.h> -#include <string.h> -#include <assert.h> - -#define UNDERSCORE2(a, b) a##_##b -#define UNDERSCORE(a, b) UNDERSCORE2(a, b) -#define VEC(n) UNDERSCORE(VEC_NAME, n) - - -#define VEC_STRUCT VEC_NAME -struct VEC_STRUCT { - size_t n; - size_t s; - VEC_TYPE *buf; -}; - -#ifndef VEC_H -#define VEC_H - -#define foreach_vec(iter, v) \ - for (size_t iter = 0; iter < (v).n; ++iter) - -#define vec_uninit(v) \ - (v.buf == NULL) -#endif - -static inline struct VEC_STRUCT VEC(create)() -{ - const size_t s = 8; - return (struct VEC_STRUCT) { - .n = 0, - .s = s, - .buf = malloc(s * sizeof(VEC_TYPE)), - }; -} - -static inline size_t VEC(len)(struct VEC_STRUCT *v) -{ - return v->n; -} - -static inline VEC_TYPE *VEC(at)(struct VEC_STRUCT *v, size_t i) -{ - assert(i < v->n && "out of vector bounds"); - return &v->buf[i]; -} - -static inline VEC_TYPE *VEC(back)(struct VEC_STRUCT *v) -{ - assert(v->n); - return &v->buf[v->n - 1]; -} - -static inline VEC_TYPE *VEC(pop)(struct VEC_STRUCT *v) -{ - assert(v->n && "attempting to pop empty vector"); - v->n--; - return &v->buf[v->n]; -} - -static inline void VEC(append)(struct VEC_STRUCT *v, VEC_TYPE n) -{ - v->n++; - if (v->n >= v->s) { - v->s *= 2; - v->buf = realloc(v->buf, v->s * sizeof(VEC_TYPE)); - assert(v->buf); - } - - v->buf[v->n - 1] = n; -} - -static inline void VEC(reset)(struct VEC_STRUCT *v) -{ - v->n = 0; -} - -static inline void VEC(destroy)(struct VEC_STRUCT *v) { - free(v->buf); -} - -typedef int (*vec_comp_t)(void *a, void *b); -static inline void VEC(sort)(struct VEC_STRUCT *v, vec_comp_t comp) -{ - qsort(v->buf, v->n, sizeof(VEC_TYPE), (__compar_fn_t)comp); -} - -static inline void VEC(reserve)(struct VEC_STRUCT *v, size_t n) -{ - if (v->n >= n) - return; - - v->n = n; - if (v->s >= v->n) - return; - - while (v->s < v->n) - v->s *= 2; - - v->buf = realloc(v->buf, v->s * sizeof(VEC_TYPE)); -} - -static inline void VEC(shrink)(struct VEC_STRUCT *v, size_t n) -{ - /* assert(v->n >= n); */ - v->n = n; -} - -#undef VEC_TYPE -#undef VEC_NAME diff --git a/tests/z_matrix_mult.c b/tests/z_matrix_mult.c new file mode 100644 index 0000000..8cfba69 --- /dev/null +++ b/tests/z_matrix_mult.c @@ -0,0 +1,143 @@ +#include <ejit/ejit.h> +#include <assert.h> +#include "do_jit.h" + +/* register allocator had a bug where registers whose last use was midway + * through a loop might get overwritten, here's one case that exposed the bug */ + +#define X 10 +int A[X][X]; +int B[X][X]; +int C[X][X]; + +bool do_jit = false; + +static void init_matrices(int A[X][X], int B[X][X]) +{ + int counter = 1; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + A[i][j] = counter; + B[i][j] = counter; + C[i][j] = 0; + + counter++; + } +} + +static int hash(int C[X][X]) +{ + int h = 0; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + h += C[i][j]; + } + + return h; +} + +static struct ejit_func *compile() +{ +#define IR EJIT_GPR(0) +#define JR EJIT_GPR(1) +#define KR EJIT_GPR(2) +#define AR EJIT_GPR(3) +#define BR EJIT_GPR(4) +#define CR EJIT_GPR(5) +#define AX EJIT_GPR(6) +#define BX EJIT_GPR(7) +#define CX EJIT_GPR(8) +#define RR EJIT_GPR(9) +#define TMP EJIT_GPR(10) + + struct ejit_operand args[3] = { + EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), /* A */ + EJIT_OPERAND_GPR(4, EJIT_TYPE(int *)), /* B */ + EJIT_OPERAND_GPR(5, EJIT_TYPE(int *)), /* C */ + }; + struct ejit_func *f = ejit_create_func(EJIT_VOID, 3, args); + + ejit_movi(f, IR, 0); + struct ejit_label out = ejit_label(f); + struct ejit_reloc outer = ejit_bgei(f, IR, X); + + ejit_movi(f, JR, 0); + struct ejit_label mid = ejit_label(f); + struct ejit_reloc midder = ejit_bgei(f, JR, X); + + ejit_movi(f, RR, 0); + + ejit_movi(f, KR, 0); + struct ejit_label in = ejit_label(f); + struct ejit_reloc inner = ejit_bgei(f, KR, X); + + /* A[i][k] at addr 4 * (i * 400 + k) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, AX, IR, TMP); + ejit_addr(f, AX, AX, KR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, AX, AX, 2); + EJIT_LDXR(f, int, AX, AR, AX); + + /* B[k][j] at addr 4 * (k * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, BX, KR, TMP); + ejit_addr(f, BX, BX, JR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, BX, BX, 2); + EJIT_LDXR(f, int, BX, BR, BX); + + /* R += A[i][k] * B[k][j] */ + ejit_mulr(f, TMP, AX, BX); + ejit_addr(f, RR, RR, TMP); + + /* increment inner */ + ejit_addi(f, KR, KR, 1); + ejit_patch(f, ejit_jmp(f), in); + /* end of inner */ + ejit_patch(f, inner, ejit_label(f)); + + /* C[i][j] at addr 4 * (i * 400 + j) */ + ejit_movi(f, TMP, X); + ejit_mulr(f, CX, IR, TMP); + ejit_addr(f, CX, CX, JR); + ejit_lshi(f, CX, CX, 2); + + /* C[i][j] = R */ + EJIT_STXR(f, int, RR, CX, CR); + + /* increment midder */ + ejit_addi(f, JR, JR, 1); + ejit_patch(f, ejit_jmp(f), mid); + /* end of midder */ + ejit_patch(f, midder, ejit_label(f)); + + /* increment outer */ + ejit_addi(f, IR, IR, 1); + ejit_patch(f, ejit_jmp(f), out); + /* end of outer */ + ejit_patch(f, outer, ejit_label(f)); + + /* return */ + ejit_ret(f); + ejit_select_compile_func(f, 11, 0, EJIT_USE64(long), do_jit, true); + return f; +} + +int main(int argc, char *argv[]) +{ + (void)argv; + do_jit = argc > 1; + + init_matrices(A, B); + struct ejit_func *f = compile(); + struct ejit_arg args[3] = { + EJIT_ARG(A, int *), + EJIT_ARG(B, int *), + EJIT_ARG(C, int *) + }; + ejit_run_func(f, 3, args); + assert(hash(C) == 2632750); + ejit_destroy_func(f); + return 0; +} |