/* SPDX-License-Identifier: copyleft-next-0.3.1 */ /* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */ #include #include #include #include #include #include #include #include #include /* stuff to be careful of: * * + Really easy to forget to pass the stack argument via reference, since it's * just a void* we don't get a warning for it * * + lower_call() and lower_closure_call() share a lot of the same stuff, might * want to try and merge them somehow, otherwise too easy to make a change in * one and forget the other */ /* placeholder, should probably scale according to biggest need */ #define FWD_FRAME_SIZE 1024 static inline void mangle(FILE *f, struct ast *id) { assert(id->k == AST_PROC_DEF || id->k == AST_VAR_DEF || id->k == AST_ID); assert(id->s && id->scope); fprintf(f, "%s_s%zu", id->s, id->scope->number); } static inline char *mangle2(struct ast *id) { char *buf = NULL; size_t size = 0; FILE *f = open_memstream(&buf, &size); assert(f); mangle(f, id); fclose(f); assert(buf); return buf; } static inline __attribute__((format (printf, 1, 2))) char *buildstr(const char *fmt, ...) { va_list args; va_start(args, fmt); char *buf = NULL; size_t size = 0; FILE *f = open_memstream(&buf, &size); assert(f); vfprintf(f, fmt, args); fclose(f); va_end(args); assert(buf); return buf; } /** @todo semantics in this file are a bit unclear, should probably do some kind * of "each function starts and ends on an indented empty line" or something */ #define VEC_NAME string_vec #define VEC_TYPE char * #include #define MAP_NAME proc_set #define MAP_KEY struct ast * #define MAP_TYPE struct ast * #define MAP_CMP(a, b) ((uintptr_t)(a) - (uintptr_t)(b)) #define MAP_HASH(a) ((uintptr_t)(a)) #include struct state { char *prefix; long indent; size_t uniq; FILE *ctx; FILE *code; struct ast *current; struct string_vec *decls; struct string_vec *defns; struct string_vec *types; struct proc_set *procs; }; static struct state create_state(struct state *parent) { return (struct state){ .uniq = parent->uniq, .prefix = parent->prefix, .indent = parent->indent, .current = parent->current, .ctx = parent->ctx, .code = parent->code, .decls = parent->decls, .defns = parent->defns, .types = parent->types, .procs = parent->procs, }; } static size_t uniq(struct state *state) { return state->uniq++; } static void increase_indent(struct state *state) { state->indent++; } static void decrease_indent(struct state *state) { state->indent--; } static void indent(struct state *state) { if (state->indent != 0) fprintf(state->code, "%*c", (int)(2 * state->indent), ' '); } static bool proc_lowered(struct state *state, struct ast *proc) { assert(proc->k == AST_PROC_DEF); return proc_set_find(state->procs, proc) != NULL; } static void add_proc(struct state *state, struct ast *proc) { assert(proc->k == AST_PROC_DEF); struct ast **p = proc_set_insert(state->procs, proc, proc); assert(p && *p == proc); } static void add_decl(struct state *state, char *decl) { string_vec_append(state->decls, decl); } static void add_defn(struct state *state, char *defn) { string_vec_append(state->defns, defn); } static void add_type(struct state *state, char *type) { string_vec_append(state->types, type); } static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last); static int lower_block(struct state *state, struct ast *block, bool last); static int lower_params(struct state *state, struct ast *params); static int lower_expr(struct state *state, struct ast *expr); static int lower_proc(struct state *state, struct ast *proc); static void _type_str(FILE *f, struct type *type) { assert(type); switch (type->k) { case TYPE_I8: fprintf(f, "int8_t"); break; case TYPE_U8: fprintf(f, "uint8_t"); break; case TYPE_I16: fprintf(f, "int16_t"); break; case TYPE_U16: fprintf(f, "uint16_t"); break; case TYPE_I32: fprintf(f, "int32_t"); break; case TYPE_U32: fprintf(f, "uint32_t"); break; case TYPE_I64: fprintf(f, "int64_t"); break; case TYPE_U64: fprintf(f, "uint64_t"); break; case TYPE_CLOSURE: fprintf(f, "fwd_closure_t"); break; default: internal_error("unhandled type lowering for %s", type_str(type)); abort(); break; } } static char *lower_type_str(struct type *type) { assert(type); char *type_buf = NULL; size_t type_len = 0; FILE *f = open_memstream(&type_buf, &type_len); assert(f); _type_str(f, type); fclose(f); assert(type_buf); return type_buf; } static int lower_closure_call(struct state *state, struct ast *call, struct ast *def, bool last) { char *q = buildstr("%s_call%zu", state->prefix, uniq(state)); char *args = mangle2(def); char *ctx_buf = NULL; size_t ctx_size = 0; FILE *ctx = open_memstream(&ctx_buf, &ctx_size); fprintf(ctx, "struct %s {\n fwd_start_t start;\n", q); indent(state); fprintf(state->code, "fwd_call_t %s = ctx->%s.call;\n", q, args); indent(state); fprintf(state->code, "struct %s *%s_ctx = ctx->%s.args;\n", q, q, args); int ret = 0; size_t idx = 0; bool returning = false; foreach_node(a, call_args(call)) { returning |= a->k == AST_CLOSURE; char *type = lower_type_str(a->t); fprintf(ctx, " %s a%zu;\n", type, idx); free(type); indent(state); fprintf(state->code, "%s_ctx->a%zu = ", q, idx); int ret = lower_expr(state, a); fprintf(state->code, ";\n"); if (ret) goto out; idx++; } out: if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) { indent(state); fprintf(state->code, "fwd_stack_free(&stack, ctx);\n"); } indent(state); if (last) { fprintf(state->code, "FWD_MUSTTAIL return %s(stack, %s_ctx);\n", q, q); } else { fprintf(state->code, "stack = %s_call(stack, %s_ctx);\n", q, q); } fprintf(ctx, "};\n\n"); fclose(ctx); assert(ctx_buf); add_type(state, ctx_buf); free(args); free(q); return ret; } static int lower_closure(struct state *state, struct ast *closure, char **name_out, char **args_out) { char *name = buildstr("%s_closure%zu", state->prefix, uniq(state)); char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name); char *decl = buildstr("%s;\n", proto); char *start_of = buildstr("%s_start", name); add_decl(state, decl); char *code_buf = NULL; size_t code_size = 0; struct state new_state = create_state(state); new_state.code = open_memstream(&code_buf, &code_size); new_state.indent = 0; assert(new_state.code); fprintf(new_state.code, "%s\n{\n", proto); /** @todo unsure if this has the same address as the first element in * the frame or the last in the previous, I guess we shall soon see */ fprintf(new_state.ctx, " fwd_start_t %s;\n", start_of); int ret = lower_params(&new_state, closure_bindings(closure)); assert(ret == 0); increase_indent(&new_state); indent(&new_state); fprintf(new_state.code, "struct %s_ctx *ctx = FWD_CONTAINER_OF(args, struct %s_ctx, %s);\n", state->prefix, state->prefix, start_of); struct ast *block = closure_body(closure); ret = lower_stmt_list(&new_state, block_body(block), true); fprintf(new_state.code, "}\n\n"); fclose(new_state.code); assert(code_buf); add_defn(state, code_buf); free(proto); assert(name_out); assert(args_out); *name_out = name; *args_out = start_of; return ret; } static int lower_comparison(struct state *state, struct ast *expr) { if (lower_expr(state, comparison_left(expr))) return -1; switch (expr->k) { case AST_LE: fprintf(state->code, " <= "); break; case AST_LT: fprintf(state->code, " < "); break; default: internal_error("unhandled lowering for comparison %s", ast_str(expr->k)); abort(); } return lower_expr(state, comparison_right(expr)); } static int lower_binop(struct state *state, struct ast *expr) { if (lower_expr(state, binop_left(expr))) return -1; switch (expr->k) { case AST_SUB: fprintf(state->code, " - "); break; case AST_ADD: fprintf(state->code, " + "); break; default: internal_error("unhandled comparison %s", ast_str(expr->k)); abort(); } return lower_expr(state, binop_right(expr)); } static int lower_expr(struct state *state, struct ast *expr) { if (is_comparison(expr)) return lower_comparison(state, expr); if (is_binop(expr)) return lower_binop(state, expr); switch (expr->k) { case AST_CLOSURE: { char *name = NULL, *args = NULL; if (lower_closure(state, expr, &name, &args)) return -1; fprintf(state->code, "(fwd_closure_t){%s, &ctx->%s}", name, args); free(name); free(args); break; } case AST_ID: { struct ast *def = file_scope_find_symbol(expr->scope, id_str(expr)); char *name = mangle2(def); fprintf(state->code, "ctx->%s", name); free(name); break; } case AST_CONST_INT: { fprintf(state->code, "%lld", (long long)int_val(expr)); break; } default: internal_error("unhandled expr lowering: %s", ast_str(expr->k)); abort(); } return 0; } static int lower_call(struct state *state, struct ast *call, bool last) { struct ast *expr = call_expr(call); assert(expr->k == AST_ID); struct ast *def = file_scope_find_symbol(call->scope, id_str(expr)); assert(def); if (def->k == AST_VAR_DEF) return lower_closure_call(state, call, def, last); if (lower_proc(state, def)) return -1; char *q = buildstr("%s_call%zu", state->prefix, uniq(state)); struct state ctx_state = create_state(state); char *ctx_buf = NULL; size_t ctx_size = 0; FILE *ctx = open_memstream(&ctx_buf, &ctx_size); fprintf(ctx, "struct %s {\n fwd_start_t start;\n", q); indent(state); fprintf(state->code, "struct %s *%s = ctx->global_args;\n", q, q); int ret = 0; size_t idx = 0; bool returning = false; foreach_node(a, call_args(call)) { returning |= a->k == AST_CLOSURE; char *type = lower_type_str(a->t); fprintf(ctx, " %s a%zu;\n", type, idx); free(type); indent(&ctx_state); fprintf(ctx_state.code, "%s->a%zu = ", q, idx); int ret = lower_expr(&ctx_state, a); fprintf(ctx_state.code, ";\n"); if (ret) goto out; idx++; } out: if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) { /** @todo unsure if this applies to all cases but seems to work * for the simple examples I've tried so far */ indent(&ctx_state); fprintf(ctx_state.code, "fwd_stack_free(&stack, ctx);\n"); } char *target = mangle2(def); indent(&ctx_state); if (last) { fprintf(ctx_state.code, "FWD_MUSTTAIL return %s(stack, %s);\n", target, q); } else { fprintf(ctx_state.code, "stack = %s(stack, %s);\n", target, q); } fprintf(ctx, "};\n\n"); fclose(ctx); assert(ctx_buf); add_type(state, ctx_buf); free(target); free(q); return ret; } static int lower_if(struct state *state, struct ast *stmt, bool last) { indent(state); fprintf(state->code, "if ("); if (lower_expr(state, if_cond(stmt))) return -1; fprintf(state->code, ")\n"); if (lower_block(state, if_body(stmt), last)) return -1; if (!if_else(stmt)) return 0; indent(state); fprintf(state->code, "else\n"); return lower_block(state, if_else(stmt), last); } static int lower_stmt(struct state *state, struct ast *stmt, bool last) { switch (stmt->k) { case AST_CALL: return lower_call(state, stmt, last); case AST_IF: return lower_if(state, stmt, last); default: internal_error("unhandled statement kind %s", ast_str(stmt->k)); abort(); break; } return 0; } static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last) { foreach_node(stmt, stmt_list) { bool req_ret = last && !stmt->n; if (lower_stmt(state, stmt, req_ret)) return -1; } return 0; } static int lower_block(struct state *state, struct ast *block, bool last) { indent(state); fprintf(state->code, "{\n"); increase_indent(state); int ret = lower_stmt_list(state, block_body(block), last); decrease_indent(state); indent(state); fprintf(state->code, "}\n"); return ret; } static int lower_param(struct state *state, struct ast *param) { char *type = lower_type_str(param->t); char *name = mangle2(param); fprintf(state->ctx, " %s %s;\n", type, name); free(type); free(name); return 0; } static int lower_params(struct state *state, struct ast *params) { foreach_node(p, params) { if (lower_param(state, p)) return -1; } return 0; } static size_t fwd_typeid(struct type *t) { /** @todo this should maybe be cached somewhere earlier? */ switch (t->k) { case TYPE_I64: return FWD_I64; case TYPE_PTR: return FWD_PTR; default: abort(); } return 0; } static const char *fwd_typeparam(struct type *t) { switch (t->k) { case TYPE_I64: return "i64"; case TYPE_PTR: return "p"; default: abort(); } return 0; } static int lower_extern_proc(struct state *state, struct ast *proc) { /* only void external functions supported atm */ struct type *rtype = proc_rtype(proc); assert(rtype->k == TYPE_VOID); add_proc(state, proc); char *name = mangle2(proc); char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name); char *decl = buildstr("%s;\n", proto); add_decl(state, decl); char *code_buf = NULL, *ctx_buf = NULL; size_t code_size = 0, ctx_size = 0; struct state new_state = create_state(state); new_state.indent = 0; new_state.prefix = name; new_state.current = proc; new_state.uniq = 0; new_state.code = open_memstream(&code_buf, &code_size); new_state.ctx = open_memstream(&ctx_buf, &ctx_size); assert(new_state.code); fprintf(new_state.ctx, "struct %s_ctx {\n", name); fprintf(new_state.code, "%s\n{\n", proto); increase_indent(&new_state); indent(&new_state); fprintf(new_state.code, "extern long %s(fwd_extern_args_t);\n", proc_id(proc)); /* for now, allocate a new frame for arguments. Might want to * simplify this by always using the 'external' format for argument * passing, even internally */ indent(&new_state); fprintf(new_state.code, "struct %s_ctx *ctx = args;\n", name); indent(&new_state); fprintf(new_state.code, "struct fwd_arg *extern_args = fwd_stack_alloc(&stack);\n"); size_t idx = 0; foreach_node(p, proc_params(proc)) { indent(&new_state); /* leave place for return value */ fprintf(new_state.code, "extern_args[%zu] = (fwd_arg_t){%zu, " "{.%s = ctx->a%zu}};\n", idx + 1, fwd_typeid(p->t), fwd_typeparam(p->t), idx); char *type_str = lower_type_str(p->t); fprintf(new_state.ctx, " %s a%zu;\n", type_str, idx); free(type_str); idx++; } indent(&new_state); fprintf(new_state.code, "%s((fwd_extern_args_t){.argc = %zu, .args = extern_args});\n", proc_id(proc), idx); indent(&new_state); fprintf(new_state.code, "fwd_stack_free(&stack, extern_args);\n"); indent(&new_state); fprintf(new_state.code, "return stack;\n"); fprintf(new_state.code, "}\n\n"); fprintf(new_state.ctx, "};\n\n"); fclose(new_state.code); assert(code_buf); add_defn(state, code_buf); fclose(new_state.ctx); assert(ctx_buf); add_type(state, ctx_buf); free(proto); free(name); return 0; } static int lower_param_copy(struct state *state, struct ast *param, FILE *f, size_t idx) { char *type = lower_type_str(param->t); fprintf(f, " %s a%zu;\n", type, idx); free(type); return 0; } static int lower_param_copies(struct state *state, struct ast *params) { char *param_buf = NULL; size_t param_size = 0; FILE *f = open_memstream(¶m_buf, ¶m_size); fprintf(f, "struct %s_params {\n fwd_start_t start;\n", state->prefix); size_t idx = 0; foreach_node(p, params) { if (lower_param_copy(state, p, f, idx)) { fclose(f); free(param_buf); return -1; } idx++; } fprintf(f, "};\n\n"); fclose(f); assert(param_buf); add_type(state, param_buf); indent(state); fprintf(state->code, "*((struct %s_params *)&ctx->%s_start) =" " *((struct %s_params *)args);\n", state->prefix, state->prefix, state->prefix); return 0; } static int lower_proc(struct state *state, struct ast *proc) { if (proc_lowered(state, proc)) return 0; if (!proc_body(proc)) return lower_extern_proc(state, proc); add_proc(state, proc); char *name = mangle2(proc); char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name); char *decl = buildstr("%s;\n", proto); add_decl(state, decl); char *code_buf = NULL, *ctx_buf = NULL; size_t code_size = 0, ctx_size = 0; struct state new_state = create_state(state); new_state.indent = 0; new_state.prefix = name; new_state.current = proc; new_state.uniq = 0; new_state.code = open_memstream(&code_buf, &code_size); new_state.ctx = open_memstream(&ctx_buf, &ctx_size); assert(new_state.code); assert(new_state.ctx); char *start_of = buildstr("%s_start", name); fprintf(new_state.code, "%s\n{\n", proto); fprintf(new_state.ctx, "struct %s_ctx {\n", name); fprintf(new_state.ctx, " fwd_args_t global_args;\n fwd_start_t %s;\n", start_of); increase_indent(&new_state); indent(&new_state); fprintf(new_state.code, "static_assert(FWD_FRAME_SIZE >= sizeof(struct %s_ctx), \"context exceeds frame size\");\n", name); if (ast_flags(proc, AST_REQ_FRAME)) { indent(&new_state); fprintf(new_state.code, "struct %s_ctx *ctx " "= fwd_stack_alloc(&stack);\n", name); } else { indent(&new_state); fprintf(new_state.code, "struct %s_ctx ctx_buf;\n", name); indent(&new_state); fprintf(new_state.code, "struct %s_ctx *ctx = &ctx_buf;\n", name); } /* allocates parameter slots */ int ret = lower_params(&new_state, proc_params(proc)); assert(ret == 0); /* actually copies values into parameter slots */ ret = lower_param_copies(&new_state, proc_params(proc)); assert(ret == 0); indent(&new_state); fprintf(new_state.code, "ctx->global_args = args;\n"); struct ast *block = proc_body(proc); ret = lower_stmt_list(&new_state, block_body(block), true); fprintf(new_state.code, "}\n\n"); fprintf(new_state.ctx, "};\n\n"); fclose(new_state.code); fclose(new_state.ctx); assert(code_buf); assert(ctx_buf); add_defn(state, code_buf); add_type(state, ctx_buf); free(start_of); free(proto); free(name); return ret; } int lower(struct scope *root) { struct ast *main = file_scope_find_symbol(root, "main"); if (!main) { error("no main"); return -1; } if (main->k != AST_PROC_DEF) { error("main is not a procedure"); return -1; } struct string_vec defns = string_vec_create(0); struct string_vec decls = string_vec_create(0); struct string_vec types = string_vec_create(0); struct proc_set procs = proc_set_create(0); struct state state = { .indent = 0, .ctx = NULL, .code = NULL, .defns = &defns, .decls = &decls, .types = &types, .procs = &procs }; if (lower_proc(&state, main)) return -1; /* placeholder, should really be calculated to be some maximum frame * size or something */ printf("#define FWD_FRAME_SIZE %d\n", 1024); printf("#include \n\n"); foreach(string_vec, s, state.types) { puts(*s); free(*s); } foreach(string_vec, s, state.decls) { puts(*s); free(*s); } foreach(string_vec, s, state.defns) { puts(*s); free(*s); } char *name = mangle2(main); printf("int main()\n"); printf("{\n"); printf(" fwd_stack_t stack = create_fwd_stack();\n"); printf(" void *args = fwd_stack_alloc(&stack);\n"); printf(" stack = %s(stack, args);\n", name); printf(" fwd_stack_free(&stack, args);\n"); printf(" destroy_fwd_stack(&stack);\n\n"); printf("}\n"); /* modules require a register function of some kind, implement a stub */ printf("int fwd_register(struct fwd_state *state, const char *name, fwd_extern_t func, fwd_type_t rtype, ...) {return 0;}\n"); string_vec_destroy(&defns); string_vec_destroy(&decls); string_vec_destroy(&types); proc_set_destroy(&procs); free(name); return 0; }