#include #include "../../deps/lightening/lightening/lightening.c" #include "../common.h" static void *alloc_arena(size_t size) { return mmap(NULL, size, PROT_EXEC | PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); } static void free_arena(void *arena, size_t size) { munmap(arena, size); } static size_t grploc_count(struct ejit_func *f) { return f->gpr >= jit_v_num() ? 0 : f->gpr - jit_v_num(); } static size_t frploc_count(struct ejit_func *f) { return f->fpr >= jit_vf_num() ? 0 : f->fpr - jit_vf_num(); } static size_t stack_size(struct ejit_func *f) { return grploc_count(f) * sizeof(jit_uword_t) + frploc_count(f) * sizeof(jit_float64_t); } static jit_off_t stack_loc(size_t l) { assert(l >= jit_v_num()); return (l - jit_v_num()) * sizeof(jit_uword_t); } static jit_off_t stack_loc_f(struct ejit_func *f, size_t l) { assert(l >= jit_vf_num()); return grploc_count(f) * sizeof(jit_uword_t) + (l - jit_vf_num()) * sizeof(jit_float64_t); } struct reloc_helper { jit_reloc_t r; size_t to; }; static jit_gpr_t getloc(struct ejit_func *f, jit_state_t *j, size_t l, size_t i) { (void)(f); if (l < jit_v_num()) return jit_v(l); jit_ldxi(j, jit_r(i), JIT_SP, stack_loc(l)); return jit_r(i); } static jit_gpr_t getreg(struct ejit_func *f, size_t l, size_t i) { (void)(f); if (l < jit_v_num()) return jit_v(l); return jit_r(i); } static void putloc(struct ejit_func *f, jit_state_t *j, size_t l, jit_gpr_t r) { (void)(f); if (l < jit_v_num()) { assert(jit_v(l).regno == r.regno); return; } jit_stxi(j, stack_loc(l), JIT_SP, r); } static void compile_label(jit_state_t *j, size_t ii, struct vec *labels) { vect_at(jit_addr_t, *labels, ii) = jit_address(j); } static void compile_movi(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { jit_gpr_t r = getreg(f, i.r0, 0); jit_movi(j, r, i.o); putloc(f, j, i.r0, r); } static void compile_movr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { jit_gpr_t to = getreg(f, i.r0, 0); jit_gpr_t from = getreg(f, i.r1, 1); jit_movr(j, to, from); putloc(f, j, i.r0, to); } static void compile_addr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { jit_gpr_t dst = getreg(f, i.r0, 0); jit_gpr_t src0 = getloc(f, j, i.r1, 1); jit_gpr_t src1 = getloc(f, j, i.r2, 2); jit_addr(j, dst, src0, src1); putloc(f, j, i.r0, dst); } static void compile_stxi64(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { jit_gpr_t r0 = getloc(f, j, i.r0, 0); jit_gpr_t r1 = getloc(f, j, i.r1, 1); jit_stxi_l(j, i.o, r1, r0); } static void compile_ldxiu64(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { jit_gpr_t r0 = getreg(f, i.r0, 0); jit_gpr_t r1 = getloc(f, j, i.r1, 1); jit_ldxi_l(j, r0, r1, i.o); putloc(f, j, i.r0, r0); } static void compile_reg_cmp(struct ejit_func *f, jit_state_t *j, struct ejit_insn i, jit_reloc_t (*bcomp)(jit_state_t *, jit_gpr_t, jit_gpr_t), long same) { jit_gpr_t r0 = getreg(f, i.r0, 0); if (i.r1 == i.r2) { jit_movi(j, r0, same); putloc(f, j, i.r0, r0); return; } jit_gpr_t r1 = getloc(f, j, i.r1, 1); jit_gpr_t r2 = getloc(f, j, i.r2, 2); jit_reloc_t branch = bcomp(j, r1, r2); /* not equal */ jit_movi(j, r0, 0); jit_reloc_t jump = jit_jmp(j); jit_patch_there(j, branch, jit_address(j)); /* equal */ jit_movi(j, r0, 1); jit_patch_there(j, jump, jit_address(j)); /* write final result */ putloc(f, j, i.r0, r0); } static void compile_eqr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { compile_reg_cmp(f, j, i, jit_beqr, 1); } static void compile_ltr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i) { compile_reg_cmp(f, j, i, jit_bltr, 0); } static void compile_bltr(struct ejit_func *f, jit_state_t *j, struct ejit_insn i, struct vec *relocs) { jit_gpr_t c0 = getloc(f, j, i.r1, 0); jit_gpr_t c1 = getloc(f, j, i.r2, 1); jit_reloc_t r = jit_bltr(j, c0, c1); struct reloc_helper h = {.r = r, .to = i.r0}; vect_append(struct reloc_helper, *relocs, &h); } static void compile_beqi(struct ejit_func *f, jit_state_t *j, struct ejit_insn i, struct vec *relocs) { jit_gpr_t r1 = getloc(f, j, i.r1, 0); jit_reloc_t r = jit_beqi(j, r1, i.o); struct reloc_helper h = {.r = r, .to = i.r0}; vect_append(struct reloc_helper, *relocs, &h); } static void compile_bnei(struct ejit_func *f, jit_state_t *j, struct ejit_insn i, struct vec *relocs) { jit_gpr_t r1 = getloc(f, j, i.r1, 0); jit_reloc_t r = jit_bnei(j, r1, i.o); struct reloc_helper h = {.r = r, .to = i.r0}; vect_append(struct reloc_helper, *relocs, &h); } static enum jit_operand_abi jit_abi_from(enum ejit_type t) { switch (t) { case EJIT_INT8: return JIT_OPERAND_ABI_INT8; case EJIT_INT16: return JIT_OPERAND_ABI_INT16; case EJIT_INT32: return JIT_OPERAND_ABI_INT32; case EJIT_INT64: return JIT_OPERAND_ABI_INT64; case EJIT_UINT8: return JIT_OPERAND_ABI_UINT8; case EJIT_UINT16: return JIT_OPERAND_ABI_UINT16; case EJIT_UINT32: return JIT_OPERAND_ABI_UINT32; case EJIT_UINT64: return JIT_OPERAND_ABI_UINT64; case EJIT_POINTER: return JIT_OPERAND_ABI_POINTER; case EJIT_FLOAT: return JIT_OPERAND_ABI_FLOAT; case EJIT_DOUBLE: return JIT_OPERAND_ABI_DOUBLE; default: } abort(); } static size_t arg_offsetof(enum ejit_type t) { switch (t) { case EJIT_INT8: return offsetof(struct ejit_arg, c); case EJIT_INT16: return offsetof(struct ejit_arg, s); case EJIT_INT32: return offsetof(struct ejit_arg, i); case EJIT_INT64: return offsetof(struct ejit_arg, l); case EJIT_UINT8: return offsetof(struct ejit_arg, uc); case EJIT_UINT16: return offsetof(struct ejit_arg, us); case EJIT_UINT32: return offsetof(struct ejit_arg, ui); case EJIT_UINT64: return offsetof(struct ejit_arg, ul); case EJIT_POINTER: return offsetof(struct ejit_arg, p); case EJIT_FLOAT: return offsetof(struct ejit_arg, f); case EJIT_DOUBLE: return offsetof(struct ejit_arg, d); default: }; abort(); } static jit_off_t arg_offset(struct ejit_insn i) { /* index of ejit_arg in stack and offset of whatever type we're dealing * with */ return (sizeof(struct ejit_arg) * i.r0) + arg_offsetof(i.r1); } static jit_off_t type_offset(struct ejit_insn i) { return (sizeof(struct ejit_arg) * i.r0) + offsetof(struct ejit_arg, type); } static void fixup_operands(struct vec *operands, size_t fixup) { foreach_vec(i, *operands) { jit_operand_t op = vect_at(jit_operand_t, *operands, i); if (op.kind != JIT_OPERAND_KIND_MEM) continue; op.loc.mem.offset += fixup; vect_at(jit_operand_t, *operands, i) = op; } } static void compile_imm_call(jit_state_t *j, struct vec *src, struct vec *dst, void *addr, size_t argc, jit_operand_t args[argc]) { /* each move is type + arg, so twofold */ size_t movec = vec_len(src) / 2; size_t fixup = jit_align_stack(j, movec * sizeof(struct ejit_arg)); fixup_operands(src, fixup); /* note, do not fix up destination! */ /* remember to move all operands */ jit_move_operands(j, dst->buf, src->buf, movec * 2); jit_calli(j, addr, argc, args); jit_shrink_stack(j, fixup); vec_reset(src); vec_reset(dst); } static size_t compile_fn_body(struct ejit_func *f, jit_state_t *j, void *arena, size_t size) { jit_begin(j, arena, size); size_t gprs = f->gpr >= jit_v_num() ? jit_v_num() : f->gpr; size_t fprs = f->fpr >= jit_vf_num() ? jit_vf_num() : f->fpr; size_t frame = jit_enter_jit_abi(j, gprs, fprs, 0); /* very important, argc we don't really do anything with but JIR_R1 * contains the argument stack! */ jit_load_args_2(j, jit_operand_gpr(JIT_OPERAND_ABI_WORD, JIT_R0), jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_R1)); size_t stack = jit_align_stack(j, stack_size(f)); struct vec src = vec_create(sizeof(jit_operand_t)); struct vec dst = vec_create(sizeof(jit_operand_t)); struct vec relocs = vec_create(sizeof(struct reloc_helper)); struct vec labels = vec_create(sizeof(jit_addr_t)); vec_reserve(&labels, vec_len(&f->insns)); size_t label = 0; foreach_vec(ii, f->insns) { /* if we've hit a label, add it to our vector of label addresses */ if (label < vec_len(&f->labels)) { if (vect_at(size_t, f->labels, label) == ii) { compile_label(j, ii, &labels); label++; } } struct ejit_insn i = vect_at(struct ejit_insn, f->insns, ii); switch (i.op) { case MOVR: compile_movr(f, j, i); break; case MOVI: compile_movi(f, j, i); break; case ADDR: compile_addr(f, j, i); break; case STXI64: compile_stxi64(f, j, i); break; case LDXIU64: compile_ldxiu64(f, j, i); break; case EQR: compile_eqr(f, j, i); break; case LTR: compile_ltr(f, j, i); break; case BLTR: compile_bltr(f, j, i, &relocs); break; case BEQI: compile_beqi(f, j, i, &relocs); break; case BNEI: compile_bnei(f, j, i, &relocs); break; case ARG: { jit_operand_t type = jit_operand_imm(JIT_OPERAND_ABI_WORD, i.r1); jit_operand_t arg; if (i.r0 < jit_v_num()) { /* regular register */ arg = jit_operand_gpr(jit_abi_from(i.r1), jit_v(i.r2)); } else { /* stack location, note that we'll fix up the SP * offset before doing the actual call */ arg = jit_operand_mem(jit_abi_from(i.r1), JIT_SP, stack_loc(i.r0)); } vec_append(&src, &type); vec_append(&src, &arg); jit_operand_t to[2] = { jit_operand_mem(JIT_OPERAND_ABI_WORD, JIT_SP, type_offset(i)), jit_operand_mem(jit_abi_from(i.r1), JIT_SP, arg_offset(i)) }; vec_append(&dst, &to[0]); vec_append(&dst, &to[1]); break; } case ESCAPEI: { jit_operand_t args[2] = { jit_operand_imm(JIT_OPERAND_ABI_WORD, vec_len(&src) / 2), jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP) }; compile_imm_call(j, &src, &dst, (void *)i.o, 2, args); break; } case CALLI: { jit_operand_t args[3] = { jit_operand_imm(JIT_OPERAND_ABI_POINTER, i.o), jit_operand_imm(JIT_OPERAND_ABI_WORD, vec_len(&src) / 2), jit_operand_gpr(JIT_OPERAND_ABI_POINTER, JIT_SP) }; compile_imm_call(j, &src, &dst, ejit_run_func, 3, args); break; } case RET: { jit_gpr_t r = getloc(f, j, i.r0, 0); /* R0 won't get overwritten by jit_leave_jit_abi */ jit_movr(j, JIT_R0, r); jit_shrink_stack(j, stack); jit_leave_jit_abi(j, gprs, fprs, frame); jit_retr(j, JIT_R0); break; } case END: { /* 'void' return */ jit_shrink_stack(j, stack); jit_leave_jit_abi(j, gprs, fprs, frame); jit_reti(j, 0); break; } case PARAM: { /* move from argument stack to location */ jit_operand_t from = jit_operand_mem( jit_abi_from(i.r1), JIT_R1, arg_offset(i) ); jit_operand_t to; if (i.r0 < jit_v_num()) { /* regular register */ to = jit_operand_gpr(jit_abi_from(i.r1), jit_v(i.r2)); } else { /* stack location */ to = jit_operand_mem(jit_abi_from(i.r1), JIT_SP, stack_loc(i.r2)); } vec_append(&src, &from); vec_append(&dst, &to); break; } case START: { /* parameters should be done by now */ jit_move_operands(j, dst.buf, src.buf, vec_len(&src)); /* reuse for arguments */ vec_reset(&dst); vec_reset(&src); break; } default: abort(); } } foreach_vec(ri, relocs) { struct reloc_helper h = vect_at(struct reloc_helper, relocs, ri); jit_addr_t a = vect_at(jit_addr_t, labels, h.to); jit_reloc_t r = h.r; assert(a); jit_patch_there(j, r, a); } vec_destroy(&src); vec_destroy(&dst); vec_destroy(&relocs); vec_destroy(&labels); if (jit_end(j, &size)) return 0; return size; } bool ejit_compile(struct ejit_func *f) { if (!init_jit()) return false; jit_state_t *j = jit_new_state(NULL, NULL); assert(j); void *arena = NULL; size_t size = 4096; while (1) { arena = alloc_arena(size); assert(arena); size_t required_size = compile_fn_body(f, j, arena, size); if (required_size == 0) break; free_arena(arena, size); size = required_size + 4096; } jit_destroy_state(j); f->arena = arena; f->size = size; return true; }