aboutsummaryrefslogtreecommitdiff
path: root/examples/matrix_mult.c
diff options
context:
space:
mode:
Diffstat (limited to 'examples/matrix_mult.c')
-rw-r--r--examples/matrix_mult.c35
1 files changed, 18 insertions, 17 deletions
diff --git a/examples/matrix_mult.c b/examples/matrix_mult.c
index fab2319..ff8af55 100644
--- a/examples/matrix_mult.c
+++ b/examples/matrix_mult.c
@@ -1,14 +1,14 @@
#include <stdio.h>
#include "../include/ejit/ejit.h"
-#define X 400
+#define X 10
int A[X][X];
int B[X][X];
int C[X][X];
static void init_matrices(int A[X][X], int B[X][X])
{
- int counter = 0;
+ int counter = 1;
for (size_t i = 0; i < X; ++i)
for (size_t j = 0; j < X; ++j) {
A[i][j] = counter;
@@ -41,7 +41,8 @@ static struct ejit_func *compile()
#define AX EJIT_GPR(6)
#define BX EJIT_GPR(7)
#define CX EJIT_GPR(8)
-#define TMP EJIT_GPR(9)
+#define RR EJIT_GPR(9)
+#define TMP EJIT_GPR(10)
struct ejit_operand args[3] = {
EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)),
@@ -58,6 +59,8 @@ static struct ejit_func *compile()
struct ejit_label mid = ejit_label(f);
struct ejit_reloc midder = ejit_bgei(f, JR, X);
+ ejit_movi(f, RR, 0);
+
ejit_movi(f, KR, 0);
struct ejit_label in = ejit_label(f);
struct ejit_reloc inner = ejit_bgei(f, KR, X);
@@ -78,20 +81,9 @@ static struct ejit_func *compile()
ejit_lshi(f, BX, BX, 2);
EJIT_LDXR(f, int, BX, BR, BX);
- /* C[i][j] at addr 4 * (i * 400 + j) */
- ejit_movi(f, TMP, X);
- ejit_mulr(f, CX, IR, TMP);
- ejit_addr(f, CX, CX, JR);
- ejit_movi(f, TMP, 4);
- /* reuse address */
- ejit_lshi(f, TMP, CX, 2);
- EJIT_LDXR(f, int, CX, CR, TMP);
-
- ejit_mulr(f, AX, AX, BX);
- ejit_addr(f, CX, CX, AX);
-
- /* store result */
- EJIT_STXR(f, int, CX, CR, TMP);
+ /* R += A[i][k] * B[k][j] */
+ ejit_mulr(f, TMP, AX, BX);
+ ejit_addr(f, RR, RR, TMP);
/* increment inner */
ejit_addi(f, KR, KR, 1);
@@ -99,6 +91,15 @@ static struct ejit_func *compile()
/* end of inner */
ejit_patch(f, inner, ejit_label(f));
+ /* C[i][j] at byte offset 4 * (i * X + j) */
+ ejit_movi(f, TMP, X);
+ ejit_mulr(f, CX, IR, TMP);
+ ejit_addr(f, CX, CX, JR);
+ ejit_lshi(f, CX, CX, 2);
+
+ /* C[i][j] = R */
+ EJIT_STXR(f, int, RR, CX, CR);
+
/* increment midder */
ejit_addi(f, JR, JR, 1);
ejit_patch(f, ejit_jmp(f), mid);