diff options
Diffstat (limited to 'src/lower.c')
| -rw-r--r-- | src/lower.c | 1403 |
1 files changed, 1006 insertions, 397 deletions
diff --git a/src/lower.c b/src/lower.c index 89d9aca..3b6193f 100644 --- a/src/lower.c +++ b/src/lower.c @@ -5,18 +5,118 @@ #include <stdlib.h> #include <string.h> #include <stdarg.h> +#include <stdint.h> #include <assert.h> #include <fwd/lower.h> #include <fwd/scope.h> +#include <fwd/mod.h> + +/* stuff to be careful of: + * + * + Really easy to forget to pass the stack argument via reference, since it's + * just a void* we don't get a warning for it + * + * + lower_call() and lower_closure_call() share a lot of the same stuff, might + * want to try and merge them somehow, otherwise too easy to make a change in + * one and forget the other + */ + +/* placeholder, should probably scale according to biggest need */ +#define FWD_FRAME_SIZE 1024 + +static inline void mangle(FILE *f, struct ast *id) +{ + assert(id->k == AST_STRUCT_DEF + || id->k == AST_PROC_DEF + || id->k == AST_VAR_DEF + || id->k == AST_ID); + + assert(id->s && id->scope); + fprintf(f, "%s_s%zu", id->s, id->scope->number); +} + +static inline char *mangle2(struct ast *id) +{ + char *buf = NULL; size_t size = 0; + FILE *f = open_memstream(&buf, &size); + assert(f); + + mangle(f, id); + fclose(f); + assert(buf); + + return buf; +} + +static inline __attribute__((format (printf, 1, 2))) +char *buildstr(const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + + char *buf = NULL; size_t size = 0; + FILE *f = open_memstream(&buf, &size); + assert(f); + + vfprintf(f, fmt, args); + fclose(f); + + va_end(args); + assert(buf); + return buf; +} /** @todo semantics in this file are a bit unclear, should probably do some kind * of "each function starts and ends on an indented empty line" or something */ +#define VEC_NAME string_vec +#define VEC_TYPE char * +#include <conts/vec.h> + +#define MAP_NAME proc_set +#define MAP_KEY struct ast * +#define MAP_TYPE struct ast * +#define MAP_CMP(a, b) ((uintptr_t)(a) - (uintptr_t)(b)) +#define MAP_HASH(a) ((uintptr_t)(a)) +#include <conts/map.h> + struct state { + char *prefix; long indent; + size_t uniq; + FILE *ctx; + FILE *code; + struct ast *current; + struct string_vec *decls; + struct string_vec *defns; + struct string_vec *types; + struct proc_set *procs; }; +static struct state create_state(struct state *parent) +{ + return (struct state){ + .uniq = parent->uniq, + .prefix = parent->prefix, + .indent = parent->indent, + .current = parent->current, + .ctx = parent->ctx, + .code = parent->code, + .decls = parent->decls, + .defns = parent->defns, + .types = parent->types, + .procs = parent->procs, + }; +} + +static size_t uniq(struct state *state) +{ + (void)state; + static size_t q = 0; + return q++; +} + static void increase_indent(struct state *state) { state->indent++; @@ -30,624 +130,1133 @@ static void decrease_indent(struct state *state) static void indent(struct state *state) { if (state->indent != 0) - printf("%*c", (int)(2 * state->indent), ' '); + fprintf(state->code, "%*c", (int)(2 * state->indent), ' '); } -static int lower_var(struct ast *expr); - -static int lower_expr(struct state *state, struct ast *expr); -static int lower_closure(struct state *state, struct ast *closure); - -static int lower_block(struct state *state, struct ast *block, bool ret); -static int lower_statement(struct state *state, struct ast *stmt, bool ret); +static bool proc_lowered(struct state *state, struct ast *proc) +{ + assert(proc->k == AST_PROC_DEF); + return proc_set_find(state->procs, proc) != NULL; +} -static int lower_type(struct type *type); -static int lower_types(struct type *types); +static void add_proc(struct state *state, struct ast *proc) +{ + assert(proc->k == AST_PROC_DEF); + struct ast **p = proc_set_insert(state->procs, proc, proc); + assert(p && *p == proc); +} -static int lower_binop(struct state *state, struct ast *binop) +static void add_decl(struct state *state, char *decl) { - printf("("); - if (lower_expr(state, binop_left(binop))) - return -1; + string_vec_append(state->decls, decl); +} - switch (binop->k) { - case AST_ADD: printf(" + "); break; - case AST_SUB: printf(" - "); break; - case AST_MUL: printf(" * "); break; - case AST_DIV: printf(" / "); break; - case AST_REM: printf(" %% "); break; - case AST_LSHIFT: printf(" << "); break; - case AST_RSHIFT: printf(" >> "); break; - default: - internal_error("missing binop lowering"); - return -1; - } +static void add_defn(struct state *state, char *defn) +{ + string_vec_append(state->defns, defn); +} - if (lower_expr(state, binop_right(binop))) - return -1; +static bool type_lowered(struct ast *type) +{ + assert(type->k == AST_STRUCT_DEF); + return ast_flags(type, AST_FLAG_LOWERED); +} - printf(")"); - return 0; +static void add_type(struct state *state, char *type) +{ + string_vec_append(state->types, type); } -static int lower_comparison(struct state *state, struct ast *comp) +static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last); +static int lower_block(struct state *state, struct ast *block, bool last); +static int lower_params(struct state *state, struct ast *params); +static int lower_expr(struct state *state, struct ast *expr); +static int lower_proc(struct state *state, struct ast *proc); +static int lower_type_def(struct state *state, struct ast *type); +static char *lower_type_str(struct state *state, struct scope *scope, struct type *type); + +static int _type_str(FILE *f, struct state *state, struct scope *scope, struct type *type) { - printf("("); - if (lower_expr(state, comparison_left(comp))) - return -1; + assert(type); + switch (type->k) { + case TYPE_I8: + fprintf(f, "int8_t"); + break; - switch (comp->k) { - case AST_LT: printf(" < "); break; - case AST_GT: printf(" > "); break; - case AST_LE: printf(" <= "); break; - case AST_GE: printf(" <= "); break; - case AST_NE: printf(" != "); break; - case AST_EQ: printf(" == "); break; - default: - internal_error("missing comparison lowering"); - return -1; - } + case TYPE_U8: + fprintf(f, "uint8_t"); + break; - if (lower_expr(state, comparison_right(comp))) - return -1; + case TYPE_I16: + fprintf(f, "int16_t"); + break; - printf(")"); - return 0; -} + case TYPE_U16: + fprintf(f, "uint16_t"); + break; -static int lower_exprs(struct state *state, struct ast *exprs) -{ - if (!exprs) - return 0; + case TYPE_I32: + fprintf(f, "int32_t"); + break; - if (lower_expr(state, exprs)) - return -1; + case TYPE_U32: + fprintf(f, "uint32_t"); + break; - foreach_node(expr, exprs->n) { - printf(", "); - if (lower_expr(state, expr)) + case TYPE_I64: + fprintf(f, "int64_t"); + break; + + case TYPE_U64: + fprintf(f, "uint64_t"); + break; + + case TYPE_CLOSURE: + fprintf(f, "fwd_closure_t"); + break; + + case TYPE_PURE_CLOSURE: + fprintf(f, "fwd_closure_t"); + break; + + case TYPE_ID: { + struct ast *def = file_scope_find_type(scope, type_str(type)); + + if (!ast_flags(def, AST_FLAG_LOWERED)) + if (lower_type_def(state, def)) return -1; + + char *name = mangle2(def); + fprintf(f, "%s", name); + free(name); + break; } - return 0; -} + case TYPE_PTR: { + char *rest = lower_type_str(state, scope, tptr_base(type)); + fprintf(f, "%s*", rest); + free(rest); + break; + } -static int lower_type_construct(struct type *type) -{ - printf("%s", tconstruct_id(type)); - printf("<"); + case TYPE_REF: { + char *rest = lower_type_str(state, scope, tref_base(type)); + fprintf(f, "%s*", rest); + free(rest); + break; + } - if (lower_types(tconstruct_args(type))) - return -1; + case TYPE_BOOL: { + fprintf(f, "bool"); + break; + } + + case TYPE_NIL: { + fprintf(f, "void*"); + break; + } + + default: + internal_error("unhandled type lowering for %s", type_str(type)); + abort(); + break; + } - printf(">"); return 0; } -static int lower_type_callable(struct type *type) +static char *lower_type_str(struct state *state, struct scope *scope, struct type *type) { - /* std::function has a slight overhead compared to just using auto here, - * but auto doesn't play well with recursive templates like with our - * fib.fwd example, so use std::function for now. Eventually I might - * instead write a C backend or something to have more control over the - * exact semantics, but for now this is good enough. */ - printf("std::function<fwd_err_t("); + assert(type); - if (lower_types(type->t0)) - return -1; + char *type_buf = NULL; size_t type_len = 0; + FILE *f = open_memstream(&type_buf, &type_len); + assert(f); - printf(")>"); - return 0; -} + int r = _type_str(f, state, scope, type); + fclose(f); + assert(type_buf); -static int lower_type_ref(struct type *type) -{ - if (lower_type(tref_base(type))) - return -1; + if (r) { + free(type_buf); + return NULL; + } - printf("&"); - return 0; + return type_buf; } -static int lower_type_ptr(struct type *type) +static int lower_closure_call(struct state *state, struct ast *call, struct ast *def, bool last) { - /* would I need parentheses in some cases? */ - if (lower_type(tptr_base(type))) - return -1; + char *q = buildstr("%s_call%zu", state->prefix, uniq(state)); + char *args = mangle2(def); - printf("*"); - return 0; -} + char *ctx_buf = NULL; size_t ctx_size = 0; + FILE *ctx = open_memstream(&ctx_buf, &ctx_size); + fprintf(ctx, "struct %s {\n fwd_start_t start;\n", q); -static int lower_type(struct type *type) -{ - switch (type->k) { - case TYPE_ID: printf("%s", tid_str(type)); return 0; - case TYPE_CONSTRUCT: return lower_type_construct(type); - case TYPE_CLOSURE: return lower_type_callable(type); - case TYPE_FUNC_PTR: return lower_type_callable(type); - case TYPE_PURE_CLOSURE: return lower_type_callable(type); - case TYPE_REF: return lower_type_ref(type); - case TYPE_PTR: return lower_type_ptr(type); - default: - internal_error("missing type lowering"); - return -1; + indent(state); + fprintf(state->code, "fwd_call_t %s = ctx->%s.call;\n", q, args); + + indent(state); + fprintf(state->code, "struct %s *%s_ctx = ctx->%s.args;\n", q, q, args); + + int ret = 0; + size_t idx = 0; + bool returning = false; + foreach_node(a, call_args(call)) { + returning |= a->k == AST_CLOSURE; + + char *type = lower_type_str(state, a->scope, a->t); + fprintf(ctx, " %s a%zu;\n", type, idx); + free(type); + + indent(state); + fprintf(state->code, "%s_ctx->a%zu = ", q, idx); + int ret = lower_expr(state, a); + fprintf(state->code, ";\n"); + if (ret) + goto out; + + idx++; } - return 0; +out: + if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) { + indent(state); + fprintf(state->code, "fwd_stack_free(&stack, ctx);\n"); + } + + indent(state); + if (last) { + fprintf(state->code, "FWD_MUSTTAIL return %s(stack, %s_ctx);\n", + q, q); + } + else { + fprintf(state->code, "stack = %s(stack, %s_ctx);\n", + q, q); + } + + fprintf(ctx, "};\n\n"); + fclose(ctx); + assert(ctx_buf); + + add_type(state, ctx_buf); + free(args); + free(q); + return ret; } -static int lower_types(struct type *types) +static int lower_closure(struct state *state, struct ast *closure, char **name_out, char **args_out) { - if (!types) - return 0; + char *name = buildstr("%s_closure%zu", state->prefix, uniq(state)); + char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name); + char *decl = buildstr("%s;\n", proto); + char *start_of = buildstr("%s_start", name); + add_decl(state, decl); - if (lower_type(types)) - return -1; + char *code_buf = NULL; + size_t code_size = 0; - foreach_type(type, types->n) { - printf(", "); - if (lower_type(type)) - return -1; - } + struct state new_state = create_state(state); + new_state.code = open_memstream(&code_buf, &code_size); + new_state.indent = 0; + assert(new_state.code); - return 0; + fprintf(new_state.code, "%s\n{\n", proto); + + /** @todo unsure if this has the same address as the first element in + * the frame or the last in the previous, I guess we shall soon see */ + fprintf(new_state.ctx, " fwd_start_t %s;\n", start_of); + + int ret = lower_params(&new_state, closure_bindings(closure)); + assert(ret == 0); + + increase_indent(&new_state); + indent(&new_state); + fprintf(new_state.code, "struct %s_ctx *ctx = FWD_CONTAINER_OF(args, struct %s_ctx, %s);\n", + state->prefix, state->prefix, start_of); + + struct ast *block = closure_body(closure); + ret = lower_stmt_list(&new_state, block_body(block), true); + + fprintf(new_state.code, "}\n\n"); + fclose(new_state.code); + assert(code_buf); + + add_defn(state, code_buf); + free(proto); + + assert(name_out); + assert(args_out); + *name_out = name; + *args_out = start_of; + return ret; } -static int lower_init(struct state *state, struct ast *init) +static int lower_comparison(struct state *state, struct ast *expr) { - printf("%s", init_id(init)); + if (lower_expr(state, comparison_left(expr))) + return -1; - if (init_args(init)) { - printf("<"); - if (lower_types(init_args(init))) - return -1; + switch (expr->k) { + case AST_LE: + fprintf(state->code, " <= "); + break; - printf(">"); - } + case AST_GE: + fprintf(state->code, " >= "); + break; + + case AST_LT: + fprintf(state->code, " < "); + break; - printf("{"); + case AST_GT: + fprintf(state->code, " > "); + break; - if (lower_exprs(state, init_body(init))) - return -1; + case AST_EQ: + fprintf(state->code, " == "); + break; - printf("}"); - return 0; + case AST_NE: + fprintf(state->code, " != "); + break; + + default: + internal_error("unhandled lowering for comparison %s", ast_str(expr->k)); + abort(); + } + + return lower_expr(state, comparison_right(expr)); } -static int lower_unop(struct state *state, struct ast *expr) +static int lower_binop(struct state *state, struct ast *expr) { + if (lower_expr(state, binop_left(expr))) + return -1; + switch (expr->k) { - case AST_LNOT: printf("-"); break; - case AST_NOT: printf("~"); break; - case AST_NEG: printf("-"); break; + case AST_SUB: + fprintf(state->code, " - "); + break; + + case AST_ADD: + fprintf(state->code, " + "); + break; + + case AST_MUL: + fprintf(state->code, " * "); + break; + default: - internal_error("missing unop lowering"); - return -1; + internal_error("unhandled binop %s", ast_str(expr->k)); + abort(); } - return lower_expr(state, unop_expr(expr)); + return lower_expr(state, binop_right(expr)); } static int lower_expr(struct state *state, struct ast *expr) { - if (is_unop(expr)) - return lower_unop(state, expr); + if (is_comparison(expr)) + return lower_comparison(state, expr); if (is_binop(expr)) return lower_binop(state, expr); - if (is_comparison(expr)) - return lower_comparison(state, expr); - switch (expr->k) { - case AST_ID: printf("%s", id_str(expr)); break; - case AST_CONST_INT: printf("%lld", int_val(expr)); break; - case AST_CONST_FLOAT: printf("%f", float_val(expr)); break; - case AST_CONST_BOOL: printf("%s", bool_val(expr) ? "true" : "false"); - break; - case AST_CONST_STR: printf("\"%s\"", str_val(expr)); break; - case AST_CONST_CHAR: printf("'%c'", (char)char_val(expr)); break; - case AST_INIT: return lower_init(state, expr); - case AST_CLOSURE: return lower_closure(state, expr); - case AST_REF: return lower_expr(state, ref_base(expr)); - case AST_DEREF: return lower_expr(state, deref_base(expr)); - default: - internal_error("missing expr lowering"); - return -1; + case AST_CLOSURE: { + char *name = NULL, *args = NULL; + if (lower_closure(state, expr, &name, &args)) + return -1; + + fprintf(state->code, "(fwd_closure_t){%s, &ctx->%s}", name, args); + free(name); + free(args); + break; } - return 0; -} + case AST_ID: { + struct ast *def = file_scope_find_symbol(expr->scope, id_str(expr)); + char *name = mangle2(def); + fprintf(state->code, "ctx->%s", name); + free(name); + break; + } -static int lower_move(struct state *state, struct ast *move) -{ - if (move->k == AST_ID) { - /** @todo once I start messing about with references, moves - * should only be outputted for parameters that take ownership - */ - printf("move(%s)", id_str(move)); - return 0; + case AST_CONST_INT: { + fprintf(state->code, "%lld", (long long)int_val(expr)); + break; } - return lower_expr(state, move); -} + case AST_NIL: { + fprintf(state->code, "NULL"); + break; + } -static int lower_moves(struct state *state, struct ast *moves) -{ - if (!moves) - return 0; + case AST_CONSTRUCT: { + char *name = lower_type_str(state, expr->scope, expr->t); + fprintf(state->code, "(%s){", name); + free(name); - if (lower_move(state, moves)) - return -1; + foreach_node(n, construct_members(expr)) { + fprintf(state->code, ".%s = ", construction_id(n)); + + if (lower_expr(state, construction_expr(n))) + return -1; + + fprintf(state->code, ", "); + } + fprintf(state->code, "}"); + break; + } + + case AST_AS: { + char *name = lower_type_str(state, expr->scope, expr->t); + fprintf(state->code, "(%s)(", name); + free(name); + + if (lower_expr(state, as_expr(expr))) + return -1; + + fprintf(state->code, ")"); + break; + } - foreach_node(move, moves->n) { - printf(", "); - if (lower_move(state, move)) + case AST_SIZEOF: { + char *type = lower_type_str(state, expr->scope, sizeof_type(expr)); + fprintf(state->code, "sizeof(%s)", type); + free(type); + break; + } + + case AST_PUT: { + fprintf(state->code, "*("); + if (lower_expr(state, put_dst(expr))) + return -1; + + fprintf(state->code, ")"); + break; + } + + case AST_DEREF: { + fprintf(state->code, "*("); + if (lower_expr(state, deref_base(expr))) return -1; + + fprintf(state->code, ")"); + break; + } + + case AST_CONST_STR: { + fprintf(state->code, "\"%s\"", str_val(expr)); + break; + } + + default: + internal_error("unhandled expr lowering: %s", ast_str(expr->k)); + abort(); } return 0; } -static int lower_err_branch(struct state *state, struct ast *err) +static int lower_call(struct state *state, struct ast *call, bool last) { - if (lower_block(state, err_branch_body(err), false)) + struct ast *expr = call_expr(call); + assert(expr->k == AST_ID); + + struct ast *def = file_scope_find_symbol(call->scope, id_str(expr)); + assert(def); + + if (def->k == AST_VAR_DEF) + return lower_closure_call(state, call, def, last); + + if (lower_proc(state, def)) return -1; - printf("\n"); - return 0; -} + char *q = buildstr("%s_call%zu", state->prefix, uniq(state)); -static int lower_mark_moved(struct state *state, struct ast *moves) -{ - if (!moves) - return 0; + struct state ctx_state = create_state(state); + + char *ctx_buf = NULL; size_t ctx_size = 0; + FILE *ctx = open_memstream(&ctx_buf, &ctx_size); + fprintf(ctx, "struct %s {\n fwd_start_t start;\n", q); - foreach_node(move, moves) { - if (move->k != AST_ID) - continue; + indent(state); + fprintf(state->code, "struct %s *%s = ctx->global_args;\n", q, q); - if (is_trivially_copyable(move->t)) - continue; + int ret = 0; + size_t idx = 0; + bool returning = false; + foreach_node(a, call_args(call)) { + returning |= a->k == AST_CLOSURE; - if (is_callable(move->t)) - continue; + char *type = lower_type_str(state, a->scope, a->t); + fprintf(ctx, " %s a%zu;\n", type, idx); + free(type); - printf("%s_owned = false;\n", id_str(move)); - indent(state); + indent(&ctx_state); + fprintf(ctx_state.code, "%s->a%zu = ", q, idx); + int ret = lower_expr(&ctx_state, a); + fprintf(ctx_state.code, ";\n"); + + if (ret) + goto out; + + idx++; } - return 0; +out: + if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) { + /** @todo unsure if this applies to all cases but seems to work + * for the simple examples I've tried so far */ + indent(&ctx_state); + fprintf(ctx_state.code, "fwd_stack_free(&stack, ctx);\n"); + } + + char *target = mangle2(def); + indent(&ctx_state); + if (last) { + fprintf(ctx_state.code, "FWD_MUSTTAIL return %s(stack, %s);\n", + target, q); + } + else { + fprintf(ctx_state.code, "stack = %s(stack, %s);\n", + target, q); + } + + fprintf(ctx, "};\n\n"); + fclose(ctx); + assert(ctx_buf); + + add_type(state, ctx_buf); + free(target); + free(q); + return ret; } -/** @todo this is probably more complicated than it really needs to be, maybe - * refactor into lower_checked_call and lower_implicit_call or something for - * explicit and implicit error handling cases? */ -static int lower_call(struct state *state, struct ast *call, bool ret) +static int lower_if(struct state *state, struct ast *stmt, bool last) { - struct ast *err = call_err(call); - /** @todo better default error name? */ - const char *err_str = err ? err_branch_id(err) : "_fwd_err"; + indent(state); + fprintf(state->code, "if ("); + if (lower_expr(state, if_cond(stmt))) + return -1; - if (lower_mark_moved(state, call_args(call))) + fprintf(state->code, ")\n"); + if (lower_block(state, if_body(stmt), last)) return -1; - bool direct_ret = ret && !err; - if (direct_ret) - printf("return "); - else - printf("if (auto %s = ", err_str); + if (!if_else(stmt)) + return 0; + + indent(state); + fprintf(state->code, "else\n"); + return lower_block(state, if_else(stmt), last); +} + +static int lower_nil_check(struct state *state, struct ast *stmt, bool last) +{ + struct ast *var = nil_check_ref(stmt); + char *type = lower_type_str(state, var->scope, var->t); + char *name = mangle2(var); + + fprintf(state->ctx, " %s %s;\n", type, name); - if (lower_expr(state, call_expr(call))) + indent(state); + fprintf(state->code, "ctx->%s = ", name); + if (lower_expr(state, nil_check_expr(stmt))) return -1; - printf("("); + fprintf(state->code, ";\n"); - if (lower_moves(state, call_args(call))) + indent(state); + fprintf(state->code, "if (!ctx->%s)", name); + if (lower_block(state, nil_check_body(stmt), last)) return -1; - if (direct_ret) { - printf(");\n"); - return 0; - } + indent(state); + fprintf(state->code, "else\n"); - printf("))"); - if (err) { - if (lower_err_branch(state, err)) - return -1; + free(type); + free(name); - if (ret) { - printf("\n"); - indent(state); - printf("return nullptr;\n"); - } + return lower_block(state, nil_check_rest(stmt), last); +} - return 0; - } +static int lower_explode(struct state *state, struct ast *stmt, bool last) +{ + /* not significant to us */ + (void)last; - printf("\n"); + int u = uniq(state); + struct ast *expr = explode_expr(stmt); + + char *type = lower_type_str(state, expr->scope, expr->t); - increase_indent(state); indent(state); - decrease_indent(state); + fprintf(state->code, "%s explode_%d = ", type, u); + free(type); + + if (lower_expr(state, expr)) + return -1; + + fprintf(state->code, ";\n"); + + foreach_node(node, explode_deconstruct(stmt)) { + struct ast *var = deconstruction_var(node); - printf("return %s;\n", err_str); + char *type = lower_type_str(state, var->scope, var->t); + char *name = mangle2(var); - if (ret) { indent(state); - printf("return nullptr;\n"); + fprintf(state->code, "ctx->%s = explode_%d.%s;\n", + name, u, deconstruction_id(node)); + + fprintf(state->ctx, " %s %s;\n", type, name); + + free(type); + free(name); } return 0; } -static int lower_let(struct state *state, struct ast *let, bool ret) +static int lower_let(struct state *state, struct ast *stmt, bool last) { - if (lower_var(let_var(let))) - return -1; + struct ast *var = let_var(stmt); + char *type = lower_type_str(state, var->scope, var->t); + char *name = mangle2(var); - printf(" = "); + fprintf(state->ctx, " %s %s;\n", type, name); - if (lower_expr(state, let_expr(let))) + indent(state); + fprintf(state->code, "ctx->%s = ", name); + + free(type); + free(name); + + if (lower_expr(state, let_expr(stmt))) return -1; - printf(";\n"); - if (ret) { + fprintf(state->code, ";\n"); + + if (last) { indent(state); - printf("return nullptr;\n"); + fprintf(state->code, "return stack;\n"); } return 0; } -static int lower_if(struct state *state, struct ast *stmt, bool ret) +static int lower_write(struct state *state, struct ast *stmt) { - printf("if ("); - if (lower_expr(state, if_cond(stmt))) + indent(state); + if (lower_expr(state, write_dst(stmt))) return -1; - printf(") "); - - if (lower_block(state, if_body(stmt), ret)) + fprintf(state->code, " = "); + if (lower_expr(state, write_src(stmt))) return -1; - if (!if_else(stmt)) { - printf("\n"); - return 0; + fprintf(state->code, ";\n"); + return 0; +} + +static int lower_stmt(struct state *state, struct ast *stmt, bool last) +{ + switch (stmt->k) { + case AST_CALL: return lower_call(state, stmt, last); + case AST_IF: return lower_if(state, stmt, last); + case AST_NIL_CHECK: return lower_nil_check(state, stmt, last); + case AST_EXPLODE: return lower_explode(state, stmt, last); + case AST_LET: return lower_let(state, stmt, last); + case AST_WRITE: return lower_write(state, stmt); + case AST_FORGET: + case AST_EMPTY: + if (last) { + indent(state); + fprintf(state->code, "return stack;\n"); + } + break; + default: + internal_error("unhandled statement kind %s", ast_str(stmt->k)); + abort(); + break; } + return 0; +} - printf(" else "); - if (lower_block(state, if_else(stmt), ret)) - return -1; +static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last) +{ + foreach_node(stmt, stmt_list) { + bool req_ret = last && !stmt->n; + if (lower_stmt(state, stmt, req_ret)) + return -1; + } - printf("\n"); return 0; } -static int lower_error(struct ast *err) +static int lower_block(struct state *state, struct ast *block, bool last) { - assert(error_str(err) || error_id(err)); - if (error_str(err)) { - printf("return %s;\n", error_str(err)); - return 0; - } + indent(state); + fprintf(state->code, "{\n"); + + increase_indent(state); + int ret = lower_stmt_list(state, block_body(block), last); + decrease_indent(state); - struct ast *id = error_id(err); - printf("return %s;\n", id_str(id)); + indent(state); + fprintf(state->code, "}\n"); + return ret; +} + +static int lower_param(struct state *state, struct ast *param) +{ + char *type = lower_type_str(state, param->scope, param->t); + char *name = mangle2(param); + + fprintf(state->ctx, " %s %s;\n", type, name); + + free(type); + free(name); return 0; } -static int lower_own(struct state *state, struct ast *stmt, bool ret) +static int lower_params(struct state *state, struct ast *params) { - /** @todo name mangling */ - printf("if (!%s_owned) ", own_id(stmt)); - if (lower_block(state, own_body(stmt), ret)) - return -1; + foreach_node(p, params) { + if (lower_param(state, p)) + return -1; + } - printf("\n"); return 0; } -static int lower_statement(struct state *state, struct ast *stmt, bool ret) +static const char *fwd_typeparam(struct type *t) { - switch (stmt->k) { - case AST_OWN: return lower_own(state, stmt, ret); - case AST_LET: return lower_let(state, stmt, ret); - case AST_CALL: return lower_call(state, stmt, ret); - case AST_IF: return lower_if(state, stmt, ret); - case AST_EMPTY: - if (ret) - printf("return nullptr;\n"); - else - printf("\n"); + switch (t->k) { + case TYPE_I8: return "i8"; + case TYPE_I16: return "i16"; + case TYPE_I32: return "i32"; + case TYPE_I64: return "i64"; + case TYPE_U8: return "u8"; + case TYPE_U16: return "u16"; + case TYPE_U32: return "u32"; + case TYPE_U64: return "u64"; + case TYPE_PTR: return "p"; + case TYPE_NIL: return "p"; + case TYPE_FUNC_PTR: return "p"; + default: + abort(); + } + + return NULL; +} - return 0; +static const char *fwd_ctypestr(struct type *t) +{ + switch (t->k) { + case TYPE_I8: return "FWD_I8"; + case TYPE_I16: return "FWD_I16"; + case TYPE_I32: return "FWD_I32"; + case TYPE_I64: return "FWD_I64"; + case TYPE_U8: return "FWD_U8"; + case TYPE_U16: return "FWD_U16"; + case TYPE_U32: return "FWD_U32"; + case TYPE_U64: return "FWD_U64"; + case TYPE_PTR: return "FWD_PTR"; + case TYPE_NIL: return "FWD_PTR"; /* void ptr */ + case TYPE_FUNC_PTR: return "FWD_PTR"; default: - internal_error("missing statement lowering"); - return -1; + abort(); } - return 0; + return NULL; } -static int lower_block_vars(struct state *state, struct ast *block) +static int lower_extern_closure_call(struct state *state, struct scope *scope, struct type *rtype, size_t idx) { - struct scope *scope = block->scope; + char *q = buildstr("%s_call%zu", state->prefix, uniq(state)); - bool populated = false; - foreach(visible, n, &scope->symbols) { - struct ast *def = n->data; - if (def->k != AST_VAR_DEF) - continue; + char *ctx_buf = NULL; size_t ctx_size = 0; + FILE *ctx = open_memstream(&ctx_buf, &ctx_size); - if (is_trivially_copyable(def->t)) - continue; + char *type = lower_type_str(state, scope, rtype); + fprintf(ctx, "struct %s {\n fwd_start_t start;\n %s a0;\n};\n", q, type); + free(type); - if (is_callable(def->t)) - continue; + fclose(ctx); + assert(ctx_buf); - if (!populated) { - indent(state); - printf("[[maybe_unused]] bool %s_owned = true", - var_id(def)); + add_type(state, ctx_buf); - populated = true; - continue; - } + indent(state); + fprintf(state->code, "fwd_call_t call = ctx->a%zu.call;\n", idx); - printf(", %s_owned = true", var_id(def)); - } + indent(state); + fprintf(state->code, "struct %s *call_ctx = ctx->a%zu.args;\n", q, idx); + + indent(state); + fprintf(state->code, "call_ctx->a0 = extern_args[0].%s;\n", + fwd_typeparam(rtype)); + + indent(state); + fprintf(state->code, "fwd_stack_free(&stack, extern_args);\n"); - if (populated) - printf(";\n\n"); + indent(state); + fprintf(state->code, "FWD_MUSTTAIL return call(stack, call_ctx);\n"); + free(q); return 0; } -static int lower_block(struct state *state, struct ast *block, bool ret) +static int lower_extern_proc(struct state *state, struct ast *proc) { - printf("{\n"); - increase_indent(state); - - if (lower_block_vars(state, block)) - return -1; + add_proc(state, proc); + char *name = mangle2(proc); + char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", + name); + + char *decl = buildstr("%s;\n", proto); + add_decl(state, decl); + + char *code_buf = NULL, *ctx_buf = NULL; + size_t code_size = 0, ctx_size = 0; + + struct state new_state = create_state(state); + new_state.indent = 0; + new_state.prefix = name; + new_state.current = proc; + new_state.uniq = 0; + new_state.code = open_memstream(&code_buf, &code_size); + new_state.ctx = open_memstream(&ctx_buf, &ctx_size); + assert(new_state.code); + + fprintf(new_state.ctx, "struct %s_ctx {\n", name); + fprintf(new_state.code, "%s\n{\n", proto); + increase_indent(&new_state); + + indent(&new_state); + fprintf(new_state.code, "extern long %s(fwd_extern_args_t);\n", proc_id(proc)); + + /* for now, allocate a new frame for arguments. Might want to + * simplify this by always using the 'external' format for argument + * passing, even internally */ + indent(&new_state); + fprintf(new_state.code, "struct %s_ctx *ctx = args;\n", name); + + indent(&new_state); + fprintf(new_state.code, "struct fwd_arg *extern_args = fwd_stack_alloc(&stack);\n"); + + size_t idx = 0; + struct type *rtype = NULL; + foreach_node(p, proc_params(proc)) { + /* lower arg */ + char *type_str = lower_type_str(state, p->scope, p->t); + fprintf(new_state.ctx, " %s a%zu;\n", type_str, idx); + free(type_str); + + /* don't lower last parameter as an extern arg if it's a + * closure, since that's our return type */ + if (!p->n && is_closure_type(p->t->k)) { + rtype = p->t; + break; + } - foreach_node(stmt, block_body(block)) { - indent(state); + indent(&new_state); + /* leave place for return value */ + fprintf(new_state.code, "extern_args[%zu] = (fwd_arg_t){%s, " + "{.%s = ctx->a%zu}};\n", + idx + 1, fwd_ctypestr(p->t), fwd_typeparam(p->t), idx); - bool returning = block_error(block) ? false : ret && !stmt->n; - if (lower_statement(state, stmt, returning)) - return -1; + idx++; } - if (block_error(block)) { - indent(state); - if (lower_error(block_error(block))) + indent(&new_state); + fprintf(new_state.code, "%s((fwd_extern_args_t){.argc = %zu, .args = extern_args});\n", + proc_id(proc), idx); + + + if (rtype) { + struct type *ctype = rtype->k == TYPE_PURE_CLOSURE + ? tpure_closure_args(rtype) + : tclosure_args(rtype) + ; + + indent(&new_state); + fprintf(new_state.code, "assert(extern_args[0].t == %s);\n", + fwd_ctypestr(ctype)); + + if (lower_extern_closure_call(&new_state, proc->scope, ctype, idx)) return -1; - } - else if (!block_body(block)) { - indent(state); - printf("return nullptr;\n"); + + } else { + /* void func */ + indent(&new_state); + fprintf(new_state.code, "fwd_stack_free(&stack, extern_args);\n"); + + indent(&new_state); + fprintf(new_state.code, "return stack;\n"); } - decrease_indent(state); - indent(state); - printf("}"); - return 0; -} + fprintf(new_state.code, "}\n\n"); + fprintf(new_state.ctx, "};\n\n"); -static int lower_var(struct ast *var) -{ - if (lower_type(var_type(var))) - return -1; + fclose(new_state.code); + assert(code_buf); + add_defn(state, code_buf); + + fclose(new_state.ctx); + assert(ctx_buf); + add_type(state, ctx_buf); - printf(" %s", var_id(var)); + free(proto); + free(name); return 0; } -static int lower_vars(struct ast *vars) +static int lower_param_copy(struct state *state, struct ast *param, FILE *f, size_t idx) { - if (!vars) - return 0; - - if (lower_var(vars)) - return -1; + char *type = lower_type_str(state, param->scope, param->t); + fprintf(f, " %s a%zu;\n", type, idx); + free(type); - foreach_node(var, vars->n) { - printf(", "); - if (lower_var(var)) - return -1; - } + indent(state); + char *p = mangle2(param); + fprintf(state->code, "ctx->%s = params->a%zu;\n", p, idx); + free(p); return 0; } -static int lower_closure(struct state *state, struct ast *closure) +static int lower_param_copies(struct state *state, struct ast *params) { - printf("[&]("); - if (lower_vars(closure_bindings(closure))) - return -1; + char *param_buf = NULL; size_t param_size = 0; + FILE *f = open_memstream(¶m_buf, ¶m_size); + fprintf(f, "struct %s_params {\n fwd_start_t start;\n", state->prefix); + + size_t idx = 0; + foreach_node(p, params) { + if (lower_param_copy(state, p, f, idx)) { + fclose(f); + free(param_buf); + return -1; + } - printf(")"); + idx++; + } - if (lower_block(state, closure_body(closure), true)) - return -1; + fprintf(f, "};\n\n"); + fclose(f); + assert(param_buf); + add_type(state, param_buf); return 0; } -static int lower_proto(struct ast *proc) +static int lower_type_def(struct state *state, struct ast *type) { - /* 'extern' functions should be provided to us by whatever framework the - * user is using */ - if (!proc_body(proc)) + assert(type->k == AST_STRUCT_DEF); + if (type_lowered(type)) return 0; - printf("fwd_err_t "); - if (strcmp("main", proc_id(proc)) == 0) - printf("fwd_main("); - else - printf("%s(", proc_id(proc)); + ast_set_flags(type, AST_FLAG_LOWERED); + char *decl_buf = NULL; size_t decl_len = 0; + FILE *decl = open_memstream(&decl_buf, &decl_len); + assert(decl); - if (lower_vars(proc_params(proc))) - return -1; + char *name = mangle2(type); + fprintf(decl, "struct %s;\n", name); + fclose(decl); + + char *defn_buf = NULL; size_t defn_len = 0; + FILE *defn = open_memstream(&defn_buf, &defn_len); + assert(defn); + + fprintf(defn, "typedef struct %s {\n", name); + + foreach_node(n, struct_body(type)) { + assert(n->k == AST_VAR_DEF); + char *t = lower_type_str(state, n->scope, var_type(n)); + fprintf(defn, "\t%s %s;\n", t, var_id(n)); + free(t); + } + + fprintf(defn, "} %s;\n", name); + + fclose(defn); + free(name); + + add_type(state, decl_buf); + add_type(state, defn_buf); - printf(");\n\n"); return 0; } -static int lower_proc(struct ast *proc) +static int lower_proc(struct state *state, struct ast *proc) { - if (!proc_body(proc)) + if (proc_lowered(state, proc)) return 0; - printf("fwd_err_t "); - if (strcmp("main", proc_id(proc)) == 0) - printf("fwd_main("); - else - printf("%s(", proc_id(proc)); + ast_set_flags(proc, AST_FLAG_LOWERED); - if (lower_vars(proc_params(proc))) - return -1; + if (!proc_body(proc)) + return lower_extern_proc(state, proc); + + add_proc(state, proc); + char *name = mangle2(proc); + char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name); + char *decl = buildstr("%s;\n", proto); + add_decl(state, decl); + + char *code_buf = NULL, *ctx_buf = NULL; + size_t code_size = 0, ctx_size = 0; + + struct state new_state = create_state(state); + new_state.indent = 0; + new_state.prefix = name; + new_state.current = proc; + new_state.uniq = 0; + new_state.code = open_memstream(&code_buf, &code_size); + new_state.ctx = open_memstream(&ctx_buf, &ctx_size); + assert(new_state.code); + assert(new_state.ctx); + + char *start_of = buildstr("%s_start", name); + + fprintf(new_state.code, "%s\n{\n", proto); + fprintf(new_state.ctx, "struct %s_ctx {\n", name); + fprintf(new_state.ctx, " fwd_args_t global_args;\n fwd_start_t %s;\n", start_of); + + increase_indent(&new_state); + indent(&new_state); + fprintf(new_state.code, "static_assert(FWD_FRAME_SIZE >= sizeof(struct %s_ctx), \"context exceeds frame size\");\n", + name); + + if (ast_flags(proc, AST_REQ_FRAME)) { + indent(&new_state); + fprintf(new_state.code, "struct %s_ctx *ctx " + "= fwd_stack_alloc(&stack);\n", + name); + } + else { + indent(&new_state); + fprintf(new_state.code, "struct %s_ctx ctx_buf;\n", name); + indent(&new_state); + fprintf(new_state.code, "struct %s_ctx *ctx = &ctx_buf;\n", name); + } - printf(")\n"); + if (proc_params(proc)) { + indent(&new_state); + fprintf(new_state.code, "struct %s_params *params = (struct %s_params *)args;\n", + name, name); + } - struct state state = {0}; - if (lower_block(&state, proc_body(proc), true)) - return -1; + /* allocates parameter slots */ + int ret = lower_params(&new_state, proc_params(proc)); + assert(ret == 0); - printf("\n\n"); - return 0; + /* actually copies values into parameter slots */ + ret = lower_param_copies(&new_state, proc_params(proc)); + assert(ret == 0); + + indent(&new_state); + fprintf(new_state.code, "ctx->global_args = args;\n"); + + struct ast *block = proc_body(proc); + ret = lower_stmt_list(&new_state, block_body(block), true); + + fprintf(new_state.code, "}\n\n"); + fprintf(new_state.ctx, "};\n\n"); + + fclose(new_state.code); + fclose(new_state.ctx); + assert(code_buf); + assert(ctx_buf); + + add_defn(state, code_buf); + add_type(state, ctx_buf); + + free(start_of); + free(proto); + free(name); + + return ret; } int lower(struct scope *root) { - printf("#include <fwdlib.hpp>\n"); + struct ast *main = file_scope_find_symbol(root, "main"); + if (!main) { + error("no main"); + return -1; + } - foreach(visible, visible, &root->symbols) { - struct ast *proc = visible->data; - assert(proc->k == AST_PROC_DEF); - if (lower_proto(proc)) - return -1; + if (main->k != AST_PROC_DEF) { + error("main is not a procedure"); + return -1; } - foreach(visible, visible, &root->symbols) { - struct ast *proc = visible->data; - if (lower_proc(proc)) - return -1; + struct string_vec defns = string_vec_create(0); + struct string_vec decls = string_vec_create(0); + struct string_vec types = string_vec_create(0); + struct proc_set procs = proc_set_create(0); + + struct state state = { + .indent = 0, + .ctx = NULL, + .code = NULL, + .defns = &defns, + .decls = &decls, + .types = &types, + .procs = &procs + }; + + if (lower_proc(&state, main)) + return -1; + + /* placeholder, should really be calculated to be some maximum frame + * size or something */ + printf("#define FWD_FRAME_SIZE %d\n", 1024); + printf("#include <fwd.h>\n\n"); + + foreach(string_vec, s, state.types) { + puts(*s); + free(*s); } + foreach(string_vec, s, state.decls) { + puts(*s); + free(*s); + } - puts("int main()"); - puts("{"); - puts(" fwd_err_t err = fwd_main();"); - puts(" if (err) {"); - puts(" fprintf(stderr, \"%s\", err);"); - puts(" return -1;"); - puts(" }"); - puts("}"); + foreach(string_vec, s, state.defns) { + puts(*s); + free(*s); + } + char *name = mangle2(main); + + printf("int main()\n"); + printf("{\n"); + printf(" fwd_stack_t stack = create_fwd_stack();\n"); + printf(" void *args = fwd_stack_alloc(&stack);\n"); + printf(" stack = %s(stack, args);\n", name); + printf(" fwd_stack_free(&stack, args);\n"); + printf(" destroy_fwd_stack(&stack);\n\n"); + printf("}\n"); + + /* modules require a register function of some kind, implement a stub */ + printf("int fwd_register(struct fwd_state *state, const char *name, fwd_extern_t func, fwd_type_t rtype, ...) {return 0;}\n"); + + string_vec_destroy(&defns); + string_vec_destroy(&decls); + string_vec_destroy(&types); + proc_set_destroy(&procs); + free(name); return 0; } |
