From ba9145b0b7af2a82c62f8dfa28807958af5d0c8d Mon Sep 17 00:00:00 2001 From: Kimplul Date: Fri, 7 Mar 2025 18:50:34 +0200 Subject: make code a bit more robust + Should be more difficult to make mistakes in the future, ejit can now automatically keep track of how many register slots are used and if 64 bit mode is required. Slight runtime overhead, but not too bad. --- examples/loop.c | 5 +- examples/matrix_mult.c | 133 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 4 deletions(-) create mode 100644 examples/matrix_mult.c (limited to 'examples') diff --git a/examples/loop.c b/examples/loop.c index f4f66ae..51ba2e9 100644 --- a/examples/loop.c +++ b/examples/loop.c @@ -20,10 +20,7 @@ int main() ejit_patch(f, r, l); - /* the highest location we used was 3, so we need to request 4 locations - * for general purpose registers in total. No floating point registers, - * so 0. */ - ejit_compile_func(f, 4, 0, true); + ejit_compile_func(f); long result = ejit_run_func(f, 0, NULL); // no args so this is fine printf("%ld\n", result); diff --git a/examples/matrix_mult.c b/examples/matrix_mult.c new file mode 100644 index 0000000..04ebb9d --- /dev/null +++ b/examples/matrix_mult.c @@ -0,0 +1,133 @@ +#include +#include "../include/ejit/ejit.h" + +#define X 400 +int A[X][X]; +int B[X][X]; +int C[X][X]; + +static void init_matrices(int A[X][X], int B[X][X]) +{ + int counter = 0; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + A[i][j] = counter; + B[i][j] = counter; + C[i][j] = 0; + + counter++; + } +} + +static int hash(int C[X][X]) +{ + int h = 0; + for (size_t i = 0; i < X; ++i) + for (size_t j = 0; j < X; ++j) { + h += C[i][j]; + } + + return h; +} + +static struct ejit_func *compile() +{ +#define IR EJIT_GPR(0) +#define JR EJIT_GPR(1) +#define KR EJIT_GPR(2) +#define AR EJIT_GPR(3) +#define BR EJIT_GPR(4) +#define CR EJIT_GPR(5) +#define AX EJIT_GPR(6) +#define BX EJIT_GPR(7) +#define CX EJIT_GPR(8) +#define TMP EJIT_GPR(9) + + struct ejit_operand args[3] = { + EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)), + EJIT_OPERAND_GPR(4, EJIT_TYPE(int *)), + EJIT_OPERAND_GPR(5, EJIT_TYPE(int *)) + }; + struct ejit_func *f = ejit_create_func(EJIT_VOID, 3, args); + + ejit_movi(f, IR, 0); + struct ejit_label out = ejit_label(f); + struct ejit_reloc outer = ejit_bgei(f, IR, X); + + ejit_movi(f, JR, 0); + struct ejit_label mid = ejit_label(f); + struct ejit_reloc midder = ejit_bgei(f, JR, X); + + ejit_movi(f, KR, 0); + struct ejit_label in = ejit_label(f); + struct ejit_reloc inner = ejit_bgei(f, KR, X); + + /* A[i][k] at addr 4 * (i * 400 + k) */ + ejit_movi(f, TMP, 400); + ejit_mulr(f, AX, IR, TMP); + ejit_addr(f, AX, AX, KR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, AX, AX, 2); + EJIT_LDXR(f, int, AX, AR, AX); + + /* B[k][j] at addr 4 * (k * 400 + j) */ + ejit_movi(f, TMP, 400); + ejit_mulr(f, BX, KR, TMP); + ejit_addr(f, BX, BX, JR); + ejit_movi(f, TMP, 4); + ejit_lshi(f, BX, BX, 2); + EJIT_LDXR(f, int, BX, BR, BX); + + /* C[i][j] at addr 4 * (i * 400 + j) */ + ejit_movi(f, TMP, 400); + ejit_mulr(f, CX, IR, TMP); + ejit_addr(f, CX, CX, JR); + ejit_movi(f, TMP, 4); + /* reuse address */ + ejit_lshi(f, TMP, CX, 2); + EJIT_LDXR(f, int, CX, CR, TMP); + + ejit_mulr(f, AX, AX, BX); + ejit_addr(f, CX, CX, AX); + + /* store result */ + EJIT_STXR(f, int, CX, CR, TMP); + + /* increment inner */ + ejit_addi(f, KR, KR, 1); + ejit_patch(f, ejit_jmp(f), in); + /* end of inner */ + ejit_patch(f, inner, ejit_label(f)); + + /* increment midder */ + ejit_addi(f, JR, JR, 1); + ejit_patch(f, ejit_jmp(f), mid); + /* end of midder */ + ejit_patch(f, midder, ejit_label(f)); + + /* increment outer */ + ejit_addi(f, IR, IR, 1); + ejit_patch(f, ejit_jmp(f), out); + /* end of outer */ + ejit_patch(f, outer, ejit_label(f)); + + /* return */ + ejit_ret(f); + ejit_compile_func(f); + return f; +} + +int main() +{ + init_matrices(A, B); + struct ejit_func *f = compile(); + struct ejit_arg args[3] = { + EJIT_ARG(A, int *), + EJIT_ARG(B, int *), + EJIT_ARG(C, int *) + }; + ejit_run_func(f, 3, args); + printf("%d\n", hash(C)); + ejit_destroy_func(f); + return 0; +} -- cgit v1.2.3