#include <stdio.h>
#include "../include/ejit/ejit.h"

/* Dimension of the square matrices: each one is X rows by X columns. */
#define X 400
/* Input matrices; filled by init_matrices() and passed to the JIT-compiled
 * function as its first and second arguments. */
int A[X][X];
int B[X][X];
/* Result matrix; passed to the JIT-compiled function as its third argument
 * and summed by hash() afterwards. */
int C[X][X];

/*
 * Fill both input matrices with the row-major element index
 * (0, 1, 2, ... X*X - 1) and clear the result matrix.
 *
 * Note that the result matrix is the file-scope C, not a parameter.
 */
static void init_matrices(int A[X][X], int B[X][X])
{
	for (size_t row = 0; row < X; ++row) {
		for (size_t col = 0; col < X; ++col) {
			/* row-major position of this element */
			const int value = (int)(row * X + col);

			A[row][col] = value;
			B[row][col] = value;
			C[row][col] = 0;
		}
	}
}

/*
 * Checksum of an X-by-X matrix: the sum of all elements, reduced modulo
 * 2^N where N is the width of unsigned int.
 *
 * The sum is accumulated in an unsigned variable because the elements are
 * arbitrary int values and the total is expected to exceed INT_MAX; signed
 * overflow would be undefined behaviour, whereas unsigned arithmetic wraps.
 * The final conversion back to int is implementation-defined for values
 * above INT_MAX (two's complement wrap on common compilers).
 */
static int hash(int C[X][X])
{
	unsigned int sum = 0;
	for (size_t i = 0; i < X; ++i) {
		for (size_t j = 0; j < X; ++j) {
			sum += (unsigned int)C[i][j];
		}
	}

	return (int)sum;
}

/*
 * Build and compile a JIT function that performs C += A * B on X-by-X int
 * matrices using a plain triple loop (i, j, k).
 *
 * The generated function takes three pointer operands, bound to general
 * purpose registers 3, 4 and 5 (AR, BR, CR below): the base addresses of A,
 * B and C. It returns nothing.
 *
 * Element offsets are scaled with a left shift by 2, i.e. the generated
 * code assumes sizeof(int) == 4.
 *
 * Returns the compiled function. NOTE(review): the caller in this file
 * releases it with ejit_destroy_func(); no error checking is done on the
 * ejit calls here.
 */
static struct ejit_func *compile(void)
{
/* loop counters */
#define IR EJIT_GPR(0)
#define JR EJIT_GPR(1)
#define KR EJIT_GPR(2)
/* matrix base addresses (function arguments) */
#define AR EJIT_GPR(3)
#define BR EJIT_GPR(4)
#define CR EJIT_GPR(5)
/* element offsets, later reused to hold the loaded element values */
#define AX EJIT_GPR(6)
#define BX EJIT_GPR(7)
#define CX EJIT_GPR(8)
/* scratch: row stride for the multiplications, then byte offset into C */
#define TMP EJIT_GPR(9)

	struct ejit_operand args[3] = {
		EJIT_OPERAND_GPR(3, EJIT_TYPE(int *)),
		EJIT_OPERAND_GPR(4, EJIT_TYPE(int *)),
		EJIT_OPERAND_GPR(5, EJIT_TYPE(int *))
	};
	struct ejit_func *f = ejit_create_func(EJIT_VOID, 3, args);

	/* for (i = 0; i < X; ...) -- exit branch is patched once the end of
	 * the loop is known */
	ejit_movi(f, IR, 0);
	struct ejit_label out = ejit_label(f);
	struct ejit_reloc outer = ejit_bgei(f, IR, X);

	/* for (j = 0; j < X; ...) */
	ejit_movi(f, JR, 0);
	struct ejit_label mid = ejit_label(f);
	struct ejit_reloc midder = ejit_bgei(f, JR, X);

	/* for (k = 0; k < X; ...) */
	ejit_movi(f, KR, 0);
	struct ejit_label in = ejit_label(f);
	struct ejit_reloc inner = ejit_bgei(f, KR, X);

	/* A[i][k] at addr 4 * (i * 400 + k) */
	ejit_movi(f, TMP, X);
	ejit_mulr(f, AX, IR, TMP);
	ejit_addr(f, AX, AX, KR);
	ejit_lshi(f, AX, AX, 2);
	/* AX now holds the value of A[i][k] instead of its offset */
	EJIT_LDXR(f, int, AX, AR, AX);

	/* B[k][j] at addr 4 * (k * 400 + j) */
	ejit_movi(f, TMP, X);
	ejit_mulr(f, BX, KR, TMP);
	ejit_addr(f, BX, BX, JR);
	ejit_lshi(f, BX, BX, 2);
	/* BX now holds the value of B[k][j] instead of its offset */
	EJIT_LDXR(f, int, BX, BR, BX);

	/* C[i][j] at addr 4 * (i * 400 + j) */
	ejit_movi(f, TMP, X);
	ejit_mulr(f, CX, IR, TMP);
	ejit_addr(f, CX, CX, JR);
	/* keep the byte offset in TMP so the store below can reuse it */
	ejit_lshi(f, TMP, CX, 2);
	EJIT_LDXR(f, int, CX, CR, TMP);

	/* C[i][j] += A[i][k] * B[k][j] */
	ejit_mulr(f, AX, AX, BX);
	ejit_addr(f, CX, CX, AX);

	/* store result */
	EJIT_STXR(f, int, CX, CR, TMP);

	/* increment inner */
	ejit_addi(f, KR, KR, 1);
	ejit_patch(f, ejit_jmp(f), in);
	/* end of inner */
	ejit_patch(f, inner, ejit_label(f));

	/* increment midder */
	ejit_addi(f, JR, JR, 1);
	ejit_patch(f, ejit_jmp(f), mid);
	/* end of midder */
	ejit_patch(f, midder, ejit_label(f));

	/* increment outer */
	ejit_addi(f, IR, IR, 1);
	ejit_patch(f, ejit_jmp(f), out);
	/* end of outer */
	ejit_patch(f, outer, ejit_label(f));

	/* return */
	ejit_ret(f);
	ejit_compile_func(f);
	return f;
}

/*
 * Initialise the matrices, run the JIT-compiled multiplication on them and
 * print the checksum of the result matrix followed by a newline.
 */
int main()
{
	init_matrices(A, B);

	struct ejit_func *matmul = compile();
	struct ejit_arg call_args[3] = {
		EJIT_ARG(A, int *),
		EJIT_ARG(B, int *),
		EJIT_ARG(C, int *)
	};
	ejit_run_func(matmul, 3, call_args);

	/* checksum is taken before the compiled function is released */
	const int checksum = hash(C);
	printf("%d\n", checksum);

	ejit_destroy_func(matmul);
	return 0;
}