/* SPDX-License-Identifier: copyleft-next-0.3.1 */ /* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */ #include #include #include #include #include #include #include /** @todo semantics in this file are a bit unclear, should probably do some kind * of "each function starts and ends on an indented empty line" or something */ struct state { long indent; }; static void increase_indent(struct state *state) { state->indent++; } static void decrease_indent(struct state *state) { state->indent--; } static void indent(struct state *state) { if (state->indent != 0) printf("%*c", (int)(2 * state->indent), ' '); } static int lower_var(struct ast *expr); static int lower_expr(struct state *state, struct ast *expr); static int lower_closure(struct state *state, struct ast *closure); static int lower_block(struct state *state, struct ast *block, bool ret); static int lower_statement(struct state *state, struct ast *stmt, bool ret); static int lower_type(struct type *type); static int lower_types(struct type *types); static int lower_binop(struct state *state, struct ast *binop) { printf("("); if (lower_expr(state, binop_left(binop))) return -1; switch (binop->k) { case AST_ADD: printf(" + "); break; case AST_SUB: printf(" - "); break; case AST_MUL: printf(" * "); break; case AST_DIV: printf(" / "); break; case AST_REM: printf(" %% "); break; case AST_LSHIFT: printf(" << "); break; case AST_RSHIFT: printf(" >> "); break; default: internal_error("missing binop lowering"); return -1; } if (lower_expr(state, binop_right(binop))) return -1; printf(")"); return 0; } static int lower_comparison(struct state *state, struct ast *comp) { printf("("); if (lower_expr(state, comparison_left(comp))) return -1; switch (comp->k) { case AST_LT: printf(" < "); break; case AST_GT: printf(" > "); break; case AST_LE: printf(" <= "); break; case AST_GE: printf(" <= "); break; case AST_NE: printf(" != "); break; case AST_EQ: printf(" == "); break; default: internal_error("missing comparison lowering"); return -1; } if (lower_expr(state, comparison_right(comp))) return -1; printf(")"); return 0; } static int lower_exprs(struct state *state, struct ast *exprs) { if (!exprs) return 0; if (lower_expr(state, exprs)) return -1; foreach_node(expr, exprs->n) { printf(", "); if (lower_expr(state, expr)) return -1; } return 0; } static int lower_type_construct(struct type *type) { printf("%s", tconstruct_id(type)); printf("<"); if (lower_types(tconstruct_args(type))) return -1; printf(">"); return 0; } static int lower_type_callable(struct type *type) { /* std::function has a slight overhead compared to just using auto here, * but auto doesn't play well with recursive templates like with our * fib.fwd example, so use std::function for now. Eventually I might * instead write a C backend or something to have more control over the * exact semantics, but for now this is good enough. */ printf("std::functiont0)) return -1; printf(")>"); return 0; } static int lower_type_ref(struct type *type) { if (lower_type(tref_base(type))) return -1; printf("&"); return 0; } static int lower_type_ptr(struct type *type) { /* would I need parentheses in some cases? */ if (lower_type(tptr_base(type))) return -1; printf("*"); return 0; } static int lower_type(struct type *type) { switch (type->k) { case TYPE_ID: printf("%s", tid_str(type)); return 0; case TYPE_CONSTRUCT: return lower_type_construct(type); case TYPE_CLOSURE: return lower_type_callable(type); case TYPE_FUNC_PTR: return lower_type_callable(type); case TYPE_PURE_CLOSURE: return lower_type_callable(type); case TYPE_REF: return lower_type_ref(type); case TYPE_PTR: return lower_type_ptr(type); default: internal_error("missing type lowering"); return -1; } return 0; } static int lower_types(struct type *types) { if (!types) return 0; if (lower_type(types)) return -1; foreach_type(type, types->n) { printf(", "); if (lower_type(type)) return -1; } return 0; } static int lower_init(struct state *state, struct ast *init) { printf("%s", init_id(init)); if (init_args(init)) { printf("<"); if (lower_types(init_args(init))) return -1; printf(">"); } printf("{"); if (lower_exprs(state, init_body(init))) return -1; printf("}"); return 0; } static int lower_unop(struct state *state, struct ast *expr) { switch (expr->k) { case AST_LNOT: printf("-"); break; case AST_NOT: printf("~"); break; case AST_NEG: printf("-"); break; default: internal_error("missing unop lowering"); return -1; } return lower_expr(state, unop_expr(expr)); } static int lower_expr(struct state *state, struct ast *expr) { if (is_unop(expr)) return lower_unop(state, expr); if (is_binop(expr)) return lower_binop(state, expr); if (is_comparison(expr)) return lower_comparison(state, expr); switch (expr->k) { case AST_ID: printf("%s", id_str(expr)); break; case AST_CONST_INT: printf("%lld", int_val(expr)); break; case AST_CONST_FLOAT: printf("%f", float_val(expr)); break; case AST_CONST_BOOL: printf("%s", bool_val(expr) ? "true" : "false"); break; case AST_CONST_STR: printf("\"%s\"", str_val(expr)); break; case AST_CONST_CHAR: printf("'%c'", (char)char_val(expr)); break; case AST_INIT: return lower_init(state, expr); case AST_CLOSURE: return lower_closure(state, expr); case AST_REF: return lower_expr(state, ref_base(expr)); case AST_DEREF: return lower_expr(state, deref_base(expr)); default: internal_error("missing expr lowering"); return -1; } return 0; } static int lower_move(struct state *state, struct ast *move) { if (move->k == AST_ID) { /** @todo once I start messing about with references, moves * should only be outputted for parameters that take ownership */ printf("move(%s)", id_str(move)); return 0; } return lower_expr(state, move); } static int lower_moves(struct state *state, struct ast *moves) { if (!moves) return 0; if (lower_move(state, moves)) return -1; foreach_node(move, moves->n) { printf(", "); if (lower_move(state, move)) return -1; } return 0; } static int lower_err_branch(struct state *state, struct ast *err) { if (lower_block(state, err_branch_body(err), false)) return -1; printf("\n"); return 0; } static int lower_mark_moved(struct state *state, struct ast *moves) { if (!moves) return 0; foreach_node(move, moves) { if (move->k != AST_ID) continue; if (is_trivially_copyable(move->t)) continue; if (is_callable(move->t)) continue; printf("%s_owned = false;\n", id_str(move)); indent(state); } return 0; } /** @todo this is probably more complicated than it really needs to be, maybe * refactor into lower_checked_call and lower_implicit_call or something for * explicit and implicit error handling cases? */ static int lower_call(struct state *state, struct ast *call, bool ret) { struct ast *err = call_err(call); /** @todo better default error name? */ const char *err_str = err ? err_branch_id(err) : "_fwd_err"; if (lower_mark_moved(state, call_args(call))) return -1; bool direct_ret = ret && !err; if (direct_ret) printf("return "); else printf("if (auto %s = ", err_str); if (lower_expr(state, call_expr(call))) return -1; printf("("); if (lower_moves(state, call_args(call))) return -1; if (direct_ret) { printf(");\n"); return 0; } printf("))"); if (err) { if (lower_err_branch(state, err)) return -1; if (ret) { printf("\n"); indent(state); printf("return nullptr;\n"); } return 0; } printf("\n"); increase_indent(state); indent(state); decrease_indent(state); printf("return %s;\n", err_str); if (ret) { indent(state); printf("return nullptr;\n"); } return 0; } static int lower_let(struct state *state, struct ast *let, bool ret) { if (lower_var(let_var(let))) return -1; printf(" = "); if (lower_expr(state, let_expr(let))) return -1; printf(";\n"); if (ret) { indent(state); printf("return nullptr;\n"); } return 0; } static int lower_if(struct state *state, struct ast *stmt, bool ret) { printf("if ("); if (lower_expr(state, if_cond(stmt))) return -1; printf(") "); if (lower_block(state, if_body(stmt), ret)) return -1; if (!if_else(stmt)) { printf("\n"); return 0; } printf(" else "); if (lower_block(state, if_else(stmt), ret)) return -1; printf("\n"); return 0; } static int lower_error(struct ast *err) { assert(error_str(err) || error_id(err)); if (error_str(err)) { printf("return %s;\n", error_str(err)); return 0; } struct ast *id = error_id(err); printf("return %s;\n", id_str(id)); return 0; } static int lower_own(struct state *state, struct ast *stmt, bool ret) { /** @todo name mangling */ printf("if (!%s_owned) ", own_id(stmt)); if (lower_block(state, own_body(stmt), ret)) return -1; printf("\n"); return 0; } static int lower_statement(struct state *state, struct ast *stmt, bool ret) { switch (stmt->k) { case AST_OWN: return lower_own(state, stmt, ret); case AST_LET: return lower_let(state, stmt, ret); case AST_CALL: return lower_call(state, stmt, ret); case AST_IF: return lower_if(state, stmt, ret); case AST_EMPTY: if (ret) printf("return nullptr;\n"); else printf("\n"); return 0; default: internal_error("missing statement lowering"); return -1; } return 0; } static int lower_block_vars(struct state *state, struct ast *block) { struct scope *scope = block->scope; bool populated = false; foreach(visible, n, &scope->symbols) { struct ast *def = n->data; if (def->k != AST_VAR_DEF) continue; if (is_trivially_copyable(def->t)) continue; if (is_callable(def->t)) continue; if (!populated) { indent(state); printf("[[maybe_unused]] bool %s_owned = true", var_id(def)); populated = true; continue; } printf(", %s_owned = true", var_id(def)); } if (populated) printf(";\n\n"); return 0; } static int lower_block(struct state *state, struct ast *block, bool ret) { printf("{\n"); increase_indent(state); if (lower_block_vars(state, block)) return -1; foreach_node(stmt, block_body(block)) { indent(state); bool returning = block_error(block) ? false : ret && !stmt->n; if (lower_statement(state, stmt, returning)) return -1; } if (block_error(block)) { indent(state); if (lower_error(block_error(block))) return -1; } else if (!block_body(block)) { indent(state); printf("return nullptr;\n"); } decrease_indent(state); indent(state); printf("}"); return 0; } static int lower_var(struct ast *var) { if (lower_type(var_type(var))) return -1; printf(" %s", var_id(var)); return 0; } static int lower_vars(struct ast *vars) { if (!vars) return 0; if (lower_var(vars)) return -1; foreach_node(var, vars->n) { printf(", "); if (lower_var(var)) return -1; } return 0; } static int lower_closure(struct state *state, struct ast *closure) { printf("[&]("); if (lower_vars(closure_bindings(closure))) return -1; printf(")"); if (lower_block(state, closure_body(closure), true)) return -1; return 0; } static int lower_proto(struct ast *proc) { /* 'extern' functions should be provided to us by whatever framework the * user is using */ if (!proc_body(proc)) return 0; printf("fwd_err_t "); if (strcmp("main", proc_id(proc)) == 0) printf("fwd_main("); else printf("%s(", proc_id(proc)); if (lower_vars(proc_params(proc))) return -1; printf(");\n\n"); return 0; } static int lower_proc(struct ast *proc) { if (!proc_body(proc)) return 0; printf("fwd_err_t "); if (strcmp("main", proc_id(proc)) == 0) printf("fwd_main("); else printf("%s(", proc_id(proc)); if (lower_vars(proc_params(proc))) return -1; printf(")\n"); struct state state = {0}; if (lower_block(&state, proc_body(proc), true)) return -1; printf("\n\n"); return 0; } int lower(struct scope *root) { printf("#include \n"); foreach(visible, visible, &root->symbols) { struct ast *proc = visible->data; assert(proc->k == AST_PROC_DEF); if (lower_proto(proc)) return -1; } foreach(visible, visible, &root->symbols) { struct ast *proc = visible->data; if (lower_proc(proc)) return -1; } puts("int main()"); puts("{"); puts(" fwd_err_t err = fwd_main();"); puts(" if (err) {"); puts(" fprintf(stderr, \"%s\", err);"); puts(" return -1;"); puts(" }"); puts("}"); return 0; }