diff options
author | Kimplul <kimi.h.kuparinen@gmail.com> | 2025-01-09 22:26:02 +0200 |
---|---|---|
committer | Kimplul <kimi.h.kuparinen@gmail.com> | 2025-01-09 22:26:02 +0200 |
commit | 598be4cd1bdd79e4859ae30291f4d65682cc672a (patch) | |
tree | 6e7e7ad537214c78049c4b3b2ee694c3b549fa4e /src | |
parent | 6f7c2d6daa5c706d441ddc42c5c6409e5266049a (diff) | |
download | fwd-598be4cd1bdd79e4859ae30291f4d65682cc672a.tar.gz fwd-598be4cd1bdd79e4859ae30291f4d65682cc672a.zip |
initial reference checking
Diffstat (limited to 'src')
-rw-r--r-- | src/analyze.c | 77 | ||||
-rw-r--r-- | src/ast.c | 4 | ||||
-rw-r--r-- | src/debug.c | 20 | ||||
-rw-r--r-- | src/lower.c | 8 | ||||
-rw-r--r-- | src/move.c | 158 | ||||
-rw-r--r-- | src/parser.y | 38 |
6 files changed, 257 insertions, 48 deletions
diff --git a/src/analyze.c b/src/analyze.c index ab08f01..dd85c08 100644 --- a/src/analyze.c +++ b/src/analyze.c @@ -36,8 +36,8 @@ static int analyze_proc(struct state *state, struct scope *scope, if (analyze_list(&proc_state, proc_scope, proc_params(node))) return -1; - struct type *callable = tgen_callable(NULL, node->loc); - if (!callable) { + struct type *proc_type = tgen_func_ptr(NULL, node->loc); + if (!proc_type) { internal_error("failed allocating proc type"); return -1; } @@ -53,10 +53,10 @@ static int analyze_proc(struct state *state, struct scope *scope, /* otherwise same or ending group, don't increment */ param->t->group = group; - type_append(&callable->t0, clone_type(param->t)); + type_append(&proc_type->t0, clone_type(param->t)); } - node->t = callable; + node->t = proc_type; return analyze(&proc_state, proc_scope, proc_body(node)); } @@ -179,7 +179,6 @@ static int analyze_let(struct state *state, struct scope *scope, return -1; } - /** @todo check move semantics, maybe in another pass? */ node->t = l; return 0; } @@ -249,7 +248,28 @@ static int analyze_call(struct state *state, struct scope *scope, return -1; } - /** @todo check move semantics? */ + return 0; +} + +static int analyze_ref(struct state *state, struct scope *scope, + struct ast *node) +{ + struct ast *expr = ref_base(node); + if (analyze(state, scope, expr)) + return -1; + + if (!is_lvalue(expr)) { + semantic_error(node->scope, node, "trying to reference rvalue"); + return -1; + } + + struct type *ref = tgen_ref(expr->t, node->loc); + if (!ref) { + internal_error("failed allocating ref type"); + return -1; + } + + node->t = ref; return 0; } @@ -295,15 +315,17 @@ static int analyze_closure(struct state *state, struct scope *scope, if (analyze(state, closure_scope, closure_body(node))) return -1; - struct type *callable = tgen_callable(NULL, node->loc); + struct type *callable = NULL; + if (ast_flags(node, AST_FLAG_NOMOVES)) + callable = tgen_pure_closure(NULL, node->loc); + else + callable = tgen_closure(NULL, node->loc); + if (!callable) { internal_error("failed allocating closure type"); return -1; } - /** @todo use analysis to figure out if this closure can be called - * multiple times or just once */ - foreach_node(binding, closure_bindings(node)) { type_append(&callable->t0, clone_type(binding->t)); } @@ -417,28 +439,19 @@ static int analyze(struct state *state, struct scope *scope, struct ast *node) } switch (node->k) { - case AST_PROC_DEF: ret = analyze_proc (state, scope, node); - break; - case AST_VAR_DEF: ret = analyze_var (state, scope, node); - break; - case AST_BLOCK: ret = analyze_block (state, scope, node); - break; - case AST_LET: ret = analyze_let (state, scope, node); - break; - case AST_INIT: ret = analyze_init (state, scope, node); - break; - case AST_CALL: ret = analyze_call (state, scope, node); - break; - case AST_ID: ret = analyze_id (state, scope, node); - break; - case AST_IF: ret = analyze_if (state, scope, node); - break; - case AST_CLOSURE: ret = analyze_closure (state, scope, node); - break; - case AST_CONST_INT: ret = analyze_int (state, scope, node); - break; - case AST_CONST_STR: ret = analyze_str (state, scope, node); - break; + case AST_CLOSURE: ret = analyze_closure(state, scope, node); break; + case AST_PROC_DEF: ret = analyze_proc(state, scope, node); break; + case AST_VAR_DEF: ret = analyze_var(state, scope, node); break; + case AST_BLOCK: ret = analyze_block(state, scope, node); break; + case AST_INIT: ret = analyze_init(state, scope, node); break; + case AST_CALL: ret = analyze_call(state, scope, node); break; + case AST_REF: ret = analyze_ref(state, scope, node); break; + case AST_LET: ret = analyze_let(state, scope, node); break; + case AST_ID: ret = analyze_id(state, scope, node); break; + case AST_IF: ret = analyze_if(state, scope, node); break; + + case AST_CONST_INT: ret = analyze_int(state, scope, node); break; + case AST_CONST_STR: ret = analyze_str(state, scope, node); break; case AST_EMPTY: node->t = tgen_void(node->loc); @@ -228,6 +228,8 @@ void ast_dump(int depth, struct ast *n) DUMP(AST_NEG); DUMP(AST_LNOT); DUMP(AST_NOT); + DUMP(AST_REF); + DUMP(AST_DEREF); DUMP(AST_CONST_INT); DUMP(AST_CONST_CHAR); DUMP(AST_CONST_BOOL); @@ -652,6 +654,8 @@ const char *ast_str(enum ast_kind k) CASE(AST_NEG); CASE(AST_LNOT); CASE(AST_NOT); + CASE(AST_REF); + CASE(AST_DEREF); CASE(AST_CONST_INT); CASE(AST_CONST_CHAR); CASE(AST_CONST_BOOL); diff --git a/src/debug.c b/src/debug.c index e65919b..1f4d3a8 100644 --- a/src/debug.c +++ b/src/debug.c @@ -150,6 +150,12 @@ void move_error(struct ast *new_use, struct ast *prev_use) semantic_info(prev_use->scope, prev_use, "previously moved here"); } +void reference_error(struct ast *new_use, struct ast *prev_use) +{ + semantic_error(new_use->scope, new_use, "using referenced value"); + semantic_info(prev_use->scope, prev_use, "previously referenced here"); +} + static void _type_str(FILE *f, struct type *t); static void _type_list_str(FILE *f, struct type *types) @@ -189,12 +195,24 @@ static void _type_str(FILE *f, struct type *type) fprintf(f, "%s", type->id); break; - case TYPE_CALLABLE: + case TYPE_CLOSURE: fprintf(f, "("); _type_list_str(f, type->t0); fprintf(f, ")"); break; + case TYPE_PURE_CLOSURE: + fprintf(f, "&("); + _type_list_str(f, type->t0); + fprintf(f, ")"); + break; + + case TYPE_FUNC_PTR: + fprintf(f, "*("); + _type_list_str(f, type->t0); + fprintf(f, ")"); + break; + case TYPE_CONSTRUCT: fprintf(f, "%s![", type->id); _type_list_str(f, type->t0); diff --git a/src/lower.c b/src/lower.c index 4b9fae9..9f806f3 100644 --- a/src/lower.c +++ b/src/lower.c @@ -130,7 +130,7 @@ static int lower_type_callable(struct type *type) * exact semantics, but for now this is good enough. */ printf("std::function<void("); - if (lower_types(tcallable_args(type))) + if (lower_types(type->t0)) return -1; printf(")>"); @@ -161,7 +161,9 @@ static int lower_type(struct type *type) switch (type->k) { case TYPE_ID: printf("%s", tid_str(type)); return 0; case TYPE_CONSTRUCT: return lower_type_construct(type); - case TYPE_CALLABLE: return lower_type_callable(type); + case TYPE_CLOSURE: return lower_type_callable(type); + case TYPE_FUNC_PTR: return lower_type_callable(type); + case TYPE_PURE_CLOSURE: return lower_type_callable(type); case TYPE_REF: return lower_type_ref(type); case TYPE_PTR: return lower_type_ptr(type); default: @@ -245,6 +247,7 @@ static int lower_expr(struct state *state, struct ast *expr) case AST_CONST_CHAR: printf("'%c'", (char)char_val(expr)); break; case AST_INIT: return lower_init(state, expr); case AST_CLOSURE: return lower_closure(state, expr); + case AST_REF: return lower_expr(state, ref_base(expr)); default: internal_error("missing expr lowering"); return -1; @@ -341,6 +344,7 @@ static int lower_statement(struct state *state, struct ast *stmt) case AST_LET: return lower_let(state, stmt); case AST_CALL: return lower_call(state, stmt); case AST_IF: return lower_if(state, stmt); + case AST_EMPTY: return 0; default: internal_error("missing statement lowering"); return -1; @@ -12,20 +12,38 @@ struct ast_pair { struct state { struct moved moved; + struct moved referenced; struct state *parent; + + /* pure means no moves allowed */ + bool pure; }; +static bool in_pure(struct state *state) +{ + if (!state) + return false; + + if (state->pure) + return true; + + return in_pure(state->parent); +} + static struct state create_state(struct state *parent) { struct state state = {}; state.parent = parent; + state.pure = false; state.moved = moved_create(); + state.referenced = moved_create(); return state; } static void destroy_state(struct state *state) { moved_destroy(&state->moved); + moved_destroy(&state->referenced); } static struct ast_pair *find_move(struct state *state, struct ast *def) @@ -41,13 +59,33 @@ static struct ast_pair *find_move(struct state *state, struct ast *def) return find_move(state->parent, def); } +static struct ast_pair *find_reference(struct state *state, struct ast *def) +{ + if (!state) + return NULL; + + struct ast_pair search = {.def = def}; + struct ast_pair *found = moved_find(&state->referenced, search); + if (found) + return found; + + return find_reference(state->parent, def); +} + static void insert_move(struct state *state, struct ast *def, struct ast *use) { - struct ast_pair pair = {.def = def, .use = use}; + struct ast_pair pair = {.def = def, .use = use}; moved_insert(&state->moved, pair); } -static void merge(struct state *to, struct state *from) +static void insert_reference(struct state *state, struct ast *def, + struct ast *use) +{ + struct ast_pair pair = {.def = def, .use = use}; + moved_insert(&state->referenced, pair); +} + +static void merge_moves(struct state *to, struct state *from) { if (moved_len(&from->moved) == 0) return; @@ -57,6 +95,22 @@ static void merge(struct state *to, struct state *from) } } +static void merge_references(struct state *to, struct state *from) +{ + if (moved_len(&from->referenced) == 0) + return; + + foreach(moved, n, &from->referenced) { + moved_insert(&to->referenced, *n); + } +} + +static void merge(struct state *to, struct state *from) +{ + merge_moves(to, from); + merge_references(to, from); +} + static void push_up(struct state *state) { struct state *parent = state->parent; @@ -66,8 +120,60 @@ static void push_up(struct state *state) merge(parent, state); } +static void forget_references(struct state *state) +{ + moved_destroy(&state->referenced); +} + static int mvcheck(struct state *state, struct ast *node); static int mvcheck_list(struct state *state, struct ast *nodes); +static int mvcheck_statements(struct state *state, struct ast *nodes); + +static int refcheck_id(struct state *state, struct ast *node) +{ + struct ast *def = file_scope_find_symbol(node->scope, id_str(node)); + assert(def); + + if (def->k != AST_VAR_DEF) { + semantic_error(node->scope, node, "cannot reference functions"); + return -1; + } + + struct ast_pair *prev = find_move(state, def); + if (prev) { + move_error(node, prev->use); + return -1; + } + + prev = find_reference(state, def); + if (prev) { + reference_error(node, prev->use); + return -1; + } + + insert_reference(state, def, node); + return 0; +} + +static int refcheck(struct state *state, struct ast *node) +{ + assert(is_lvalue(node)); + + switch (node->k) { + case AST_ID: return refcheck_id(state, node); + default: + internal_error("unhandled node %s for refcheck", + ast_str(node->k)); + return -1; + } + + return 0; +} + +static int mvcheck_ref(struct state *state, struct ast *node) +{ + return refcheck(state, ref_base(node)); +} static int mvcheck_proc(struct state *state, struct ast *node) { @@ -84,7 +190,7 @@ static int mvcheck_proc(struct state *state, struct ast *node) static int mvcheck_block(struct state *state, struct ast *node) { - return mvcheck_list(state, block_body(node)); + return mvcheck_statements(state, block_body(node)); } static int mvcheck_call(struct state *state, struct ast *node) @@ -97,16 +203,18 @@ static int mvcheck_call(struct state *state, struct ast *node) return 0; int ret = 0; - long prev_group = args->t->group; struct state group_state = create_state(state); foreach_node(arg, call_args(node)) { struct state arg_state = create_state(state); ret = mvcheck(&arg_state, arg); - long group = arg->t->group; - if (prev_group != group) { - /* previous group ended, push states up */ + merge(&group_state, &arg_state); + destroy_state(&arg_state); + + struct ast *next = arg->n; + if (!next || next->t->group != arg->t->group) { + /* this group ended, push state immediately */ push_up(&group_state); /* something like reset_state would maybe me more clear? */ @@ -114,9 +222,6 @@ static int mvcheck_call(struct state *state, struct ast *node) group_state = create_state(state); } - merge(&group_state, &arg_state); - destroy_state(&arg_state); - prev_group = group; if (ret) break; } @@ -128,6 +233,8 @@ static int mvcheck_call(struct state *state, struct ast *node) static int mvcheck_closure(struct state *state, struct ast *node) { struct state new_state = create_state(state); + new_state.pure = ast_flags(node, AST_FLAG_NOMOVES); + int ret = mvcheck(&new_state, closure_body(node)); push_up(&new_state); @@ -164,6 +271,12 @@ static int mvcheck_id(struct state *state, struct ast *node) if (def->k != AST_VAR_DEF) return 0; + if (in_pure(state)) { + semantic_error(node->scope, node, + "move in pure context not allowed"); + return -1; + } + struct ast_pair *prev = find_move(state, def); if (prev) { /* error messages for opt groups could be improved */ @@ -171,6 +284,12 @@ static int mvcheck_id(struct state *state, struct ast *node) return -1; } + prev = find_reference(state, def); + if (prev) { + reference_error(node, prev->use); + return -1; + } + insert_move(state, def, node); if (def->or || def->ol) { @@ -252,6 +371,23 @@ static int mvcheck_list(struct state *state, struct ast *nodes) return 0; } +static int mvcheck_statements(struct state *state, struct ast *nodes) +{ + foreach_node(node, nodes) { + struct state new_state = create_state(state); + if (mvcheck(&new_state, node)) { + destroy_state(&new_state); + return -1; + } + + forget_references(&new_state); + push_up(&new_state); + destroy_state(&new_state); + } + + return 0; +} + static int mvcheck(struct state *state, struct ast *node) { if (is_unop(node)) @@ -272,6 +408,8 @@ static int mvcheck(struct state *state, struct ast *node) case AST_ID: return mvcheck_id (state, node); case AST_INIT: return mvcheck_init (state, node); case AST_IF: return mvcheck_if (state, node); + case AST_REF: return mvcheck_ref (state, node); + case AST_EMPTY: case AST_CONST_INT: case AST_CONST_STR: case AST_CONST_FLOAT: diff --git a/src/parser.y b/src/parser.y index dd177d0..751aaa1 100644 --- a/src/parser.y +++ b/src/parser.y @@ -227,9 +227,31 @@ type | APPLY "[" opt_types "]" { $$ = tgen_construct($[APPLY], $[opt_types], src_loc(@$)); } - | "(" opt_types ")" { $$ = tgen_callable($2, src_loc(@$)); } - | "&" type { $$ = tgen_ref($2, src_loc(@$)); } - | "*" type { $$ = tgen_ptr($2, src_loc(@$)); } + | "(" opt_types ")" { $$ = tgen_closure($2, src_loc(@$)); } + | "&&" type { + if ($2->k == TYPE_CLOSURE) { + $$ = $2; $$->k = TYPE_PURE_CLOSURE; + } else { + $$ = tgen_ref($2, src_loc(@$)); + } + + /* heh */ + $$ = tgen_ref($$, src_loc(@$)); + } + | "&" type { + if ($2->k == TYPE_CLOSURE) { + $$ = $2; $$->k = TYPE_PURE_CLOSURE; + } else { + $$ = tgen_ref($2, src_loc(@$)); + } + } + | "*" type { + if ($2->k == TYPE_CLOSURE) { + $$ = $2; $$->k = TYPE_FUNC_PTR; + } else { + $$ = tgen_ptr($2, src_loc(@$)); + } + } rev_types : rev_types "," type { $$ = $3; $$->n = $1; } @@ -259,6 +281,8 @@ expr | INT { $$ = gen_const_int($1, src_loc(@$)); } | ID { $$ = gen_id($1, src_loc(@$)); } | "(" expr ")" { $$ = $2; } + | expr "&" { $$ = gen_ref($1, src_loc(@$)); } + | expr "*" { $$ = gen_deref($1, src_loc(@$)); } | construct | binop | unop @@ -276,6 +300,10 @@ opt_exprs closure : "=>" opt_vars body { $$ = gen_closure($[opt_vars], $[body], src_loc(@$)); } + | "&" "=>" opt_vars body { + $$ = gen_closure($[opt_vars], $[body], src_loc(@$)); + ast_set_flags($$, AST_FLAG_NOMOVES); + } rev_closures : rev_closures closure { $$ = $2; $2->n = $1; } @@ -283,6 +311,10 @@ rev_closures trailing_closure : "=>" opt_vars ";" { $$ = gen_closure($[opt_vars], NULL, src_loc(@$));} + | "&" "=>" opt_vars ";" { + $$ = gen_closure($[opt_vars], NULL, src_loc(@$)); + ast_set_flags($$, AST_FLAG_NOMOVES); + } opt_trailing_closure : trailing_closure |