aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKimplul <kimi.h.kuparinen@gmail.com>2025-04-01 22:16:25 +0300
committerKimplul <kimi.h.kuparinen@gmail.com>2025-04-01 22:16:25 +0300
commit478c92b425eca53a0d884fb8f5dea8d769016858 (patch)
treed5d85feb1b796530429221d51ab88c8c7f1a0b55
parent4135845b93d5c0eab23ad5da526b03a911878d67 (diff)
downloadejit-478c92b425eca53a0d884fb8f5dea8d769016858.tar.gz
ejit-478c92b425eca53a0d884fb8f5dea8d769016858.zip
expose sqrt
+ Requires linking with libm in some cases, which is fine I suppose, but kind of annoying
-rw-r--r--include/ejit/ejit.h3
-rwxr-xr-xscripts/gen-tests2
-rw-r--r--src/common.h3
-rw-r--r--src/compile/compile.c20
-rw-r--r--src/ejit.c10
-rw-r--r--src/interp.c11
-rw-r--r--tests/sqrtr_d.c23
-rw-r--r--tests/sqrtr_f.c23
8 files changed, 94 insertions, 1 deletions
diff --git a/include/ejit/ejit.h b/include/ejit/ejit.h
index d4bb725..920fdc5 100644
--- a/include/ejit/ejit.h
+++ b/include/ejit/ejit.h
@@ -805,6 +805,9 @@ void ejit_truncr_d_32(struct ejit_func *s, struct ejit_gpr r0,
void ejit_truncr_d_64(struct ejit_func *s, struct ejit_gpr r0,
struct ejit_fpr r1);
+void ejit_sqrtr_f(struct ejit_func *s, struct ejit_fpr r0, struct ejit_fpr r1);
+void ejit_sqrtr_d(struct ejit_func *s, struct ejit_fpr r0, struct ejit_fpr r1);
+
struct ejit_reloc ejit_bltr(struct ejit_func *s, struct ejit_gpr r0,
struct ejit_gpr r1);
struct ejit_reloc ejit_bner(struct ejit_func *s, struct ejit_gpr r0,
diff --git a/scripts/gen-tests b/scripts/gen-tests
index 5521006..deac247 100755
--- a/scripts/gen-tests
+++ b/scripts/gen-tests
@@ -11,5 +11,5 @@ do
echo "${dep}:" >> tests.mk
echo "-include ${dep}" >> tests.mk
echo "${exe}: ${s} libejit.a" >> tests.mk
- echo " \$(COMPILE_TEST) ${s} libejit.a -o ${exe}" >> tests.mk
+ echo " \$(COMPILE_TEST) ${s} libejit.a -o ${exe} -lm" >> tests.mk
done
diff --git a/src/common.h b/src/common.h
index 6a3c754..c690f8f 100644
--- a/src/common.h
+++ b/src/common.h
@@ -142,6 +142,9 @@ enum ejit_opcode {
EJIT_OP_TRUNCR_F_32,
EJIT_OP_TRUNCR_F_64,
+ EJIT_OP_SQRTR_F,
+ EJIT_OP_SQRTR_D,
+
EJIT_OP_EQR,
EJIT_OP_NER,
EJIT_OP_GTR,
diff --git a/src/compile/compile.c b/src/compile/compile.c
index b90ee54..580b7fa 100644
--- a/src/compile/compile.c
+++ b/src/compile/compile.c
@@ -1326,6 +1326,23 @@ static void compile_truncr_f_32(struct ejit_func *f, jit_state_t *j,
#endif
}
+/* Compile EJIT_OP_SQRTR_F: single-precision square root of i.r1 into i.r0. */
+static void compile_sqrtr_f(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
+{
+ jit_fpr_t r0 = getfpr(f, i.r0, 0);
+ jit_fpr_t r1 = getloc_f(f, j, i.r1, 1);
+ jit_sqrtr_f(j, r0, r1);
+ putloc_f(f, j, i.r0, r0); /* single-precision write-back, pairing with getloc_f */
+}
+
+/* Compile EJIT_OP_SQRTR_D: double-precision square root of i.r1 into i.r0. */
+static void compile_sqrtr_d(struct ejit_func *f, jit_state_t *j, struct ejit_insn i)
+{
+ jit_fpr_t dst = getfpr(f, i.r0, 0);
+ jit_fpr_t src = getloc_d(f, j, i.r1, 1);
+ jit_sqrtr_d(j, dst, src);
+ putloc_d(f, j, i.r0, dst);
+}
static void compile_reg_cmp(struct ejit_func *f, jit_state_t *j,
struct ejit_insn i,
@@ -2084,6 +2101,9 @@ static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena,
case EJIT_OP_TRUNCR_F_32: compile_truncr_f_32(f, j, i); break;
case EJIT_OP_TRUNCR_F_64: compile_truncr_f_64(f, j, i); break;
+ case EJIT_OP_SQRTR_F: compile_sqrtr_f(f, j, i); break;
+ case EJIT_OP_SQRTR_D: compile_sqrtr_d(f, j, i); break;
+
case EJIT_OP_EQR: compile_eqr(f, j, i); break;
case EJIT_OP_EQR_F: compile_eqr_f(f, j, i); break;
case EJIT_OP_EQR_D: compile_eqr_d(f, j, i); break;
diff --git a/src/ejit.c b/src/ejit.c
index e7e2ff2..2224198 100644
--- a/src/ejit.c
+++ b/src/ejit.c
@@ -1371,6 +1371,16 @@ void ejit_truncr_f_64(struct ejit_func *s, struct ejit_gpr r0,
emit_insn_orf(s, EJIT_OP_TRUNCR_F_64, r0, f1);
}
+void ejit_sqrtr_f(struct ejit_func *s, struct ejit_fpr r0, struct ejit_fpr r1)
+{
+ emit_insn_off(s, EJIT_OP_SQRTR_F, r0, r1); /* float sqrt op; backends compute r0 = sqrt(r1) */
+}
+
+void ejit_sqrtr_d(struct ejit_func *s, struct ejit_fpr r0, struct ejit_fpr r1)
+{
+ emit_insn_off(s, EJIT_OP_SQRTR_D, r0, r1); /* double sqrt op; backends compute r0 = sqrt(r1) */
+}
+
struct ejit_reloc ejit_bner(struct ejit_func *s, struct ejit_gpr r0,
struct ejit_gpr r1)
{
diff --git a/src/interp.c b/src/interp.c
index b858f26..2d9b7c7 100644
--- a/src/interp.c
+++ b/src/interp.c
@@ -147,6 +147,9 @@ union interp_ret ejit_run(struct ejit_func *f, size_t paramc, struct ejit_arg pa
[EJIT_OP_TRUNCR_F_32] = &&TRUNCR_F_32,
[EJIT_OP_TRUNCR_F_64] = &&TRUNCR_F_64,
+ [EJIT_OP_SQRTR_F] = &&SQRTR_F,
+ [EJIT_OP_SQRTR_D] = &&SQRTR_D,
+
[EJIT_OP_BNER] = &&BNER,
[EJIT_OP_BNEI] = &&BNEI,
[EJIT_OP_BNER_F] = &&BNER_F,
@@ -784,6 +787,14 @@ union interp_ret ejit_run(struct ejit_func *f, size_t paramc, struct ejit_arg pa
gpr[i.r0] = (int64_t)fpr[i.r1].f;
DISPATCH();
+ DO(SQRTR_F);
+ fpr[i.r0].f = sqrtf(fpr[i.r1].f); /* float operand: use the single-precision routine */
+ DISPATCH();
+
+ DO(SQRTR_D);
+ fpr[i.r0].d = sqrt(fpr[i.r1].d); /* double-precision root of FPR r1 stored in FPR r0 */
+ DISPATCH();
+
DO(BNER);
if (gpr[i.r1] != gpr[i.r2])
JUMP(i.r0);
diff --git a/tests/sqrtr_d.c b/tests/sqrtr_d.c
new file mode 100644
index 0000000..06e7894
--- /dev/null
+++ b/tests/sqrtr_d.c
@@ -0,0 +1,23 @@
+#include <ejit/ejit.h>
+#include <assert.h>
+#include "do_jit.h"
+
+/* sqrtr_d test: build f(x) that square-roots FPR 0 in place and returns it. */
+int main(int argc, char *argv[])
+{
+ (void)argv;
+ /* true when any command-line argument is given; forwarded to the compile call */
+ bool do_jit = argc > 1;
+ struct ejit_operand params[1] = { EJIT_OPERAND_FPR(0, EJIT_TYPE(double)) };
+ struct ejit_func *f = ejit_create_func(EJIT_TYPE(double), 1, params);
+
+ ejit_sqrtr_d(f, EJIT_FPR(0), EJIT_FPR(0));
+ ejit_retr_d(f, EJIT_FPR(0));
+ ejit_select_compile_func(f, 0, 1, EJIT_USE64(double), do_jit, true);
+
+ assert(erfd1(f, EJIT_ARG( 0.0, double)) == 0.0);
+ assert(erfd1(f, EJIT_ARG( 4.0, double)) == 2.0);
+ /* a negative input is expected to yield NaN, which never equals itself */
+ assert(erfd1(f, EJIT_ARG(-4.0, double))
+ != erfd1(f, EJIT_ARG(-4.0, double)));
+}
diff --git a/tests/sqrtr_f.c b/tests/sqrtr_f.c
new file mode 100644
index 0000000..3baa00d
--- /dev/null
+++ b/tests/sqrtr_f.c
@@ -0,0 +1,23 @@
+#include <ejit/ejit.h>
+#include <assert.h>
+#include "do_jit.h"
+
+/* sqrtr_f test: build f(x) that square-roots FPR 0 in place and returns it. */
+int main(int argc, char *argv[])
+{
+ (void)argv;
+ /* true when any command-line argument is given; forwarded to the compile call */
+ bool do_jit = argc > 1;
+ struct ejit_operand params[1] = { EJIT_OPERAND_FPR(0, EJIT_TYPE(float)) };
+ struct ejit_func *f = ejit_create_func(EJIT_TYPE(float), 1, params);
+
+ ejit_sqrtr_f(f, EJIT_FPR(0), EJIT_FPR(0));
+ ejit_retr_f(f, EJIT_FPR(0));
+ ejit_select_compile_func(f, 0, 1, EJIT_USE64(float), do_jit, true);
+
+ assert(erff1(f, EJIT_ARG( 0.0, float)) == 0.0);
+ assert(erff1(f, EJIT_ARG( 4.0, float)) == 2.0);
+ /* a negative input is expected to yield NaN, which never equals itself */
+ assert(erff1(f, EJIT_ARG(-4.0, float))
+ != erff1(f, EJIT_ARG(-4.0, float)));
+}