#include struct ast_pair { struct ast *def; struct ast *use; }; #define SPTREE_TYPE struct ast_pair #define SPTREE_CMP(a, b) ((a).def - (b).def) #define SPTREE_NAME moved #include struct state { struct moved moved; struct moved queued; struct moved referenced; struct state *parent; /* pure means no moves allowed */ bool pure; }; static bool in_pure(struct state *state) { if (!state) return false; if (state->pure) return true; return in_pure(state->parent); } static struct state create_state(struct state *parent) { struct state state = {}; state.parent = parent; state.pure = false; state.queued = moved_create(); state.moved = moved_create(); state.referenced = moved_create(); return state; } static void destroy_state(struct state *state) { moved_destroy(&state->moved); moved_destroy(&state->queued); moved_destroy(&state->referenced); } struct rm_move { struct ast_pair data; struct state *owner; }; static struct rm_move remove_move(struct state *state, struct ast *def) { if (!state) return (struct rm_move){.data = {}, .owner = NULL}; struct ast_pair search = {.def = def}; struct ast_pair *found = moved_find(&state->moved, search); if (found) { moved_remove_found(&state->moved, found); struct rm_move r = {.data = *found, .owner = state}; moved_free_found(&state->moved, found); return r; } return remove_move(state->parent, def); } static void reinsert_move(struct rm_move move) { if (!move.owner) return; moved_insert(&move.owner->moved, move.data); } static struct ast_pair *find_move(struct state *state, struct ast *def) { if (!state) return NULL; struct ast_pair search = {.def = def}; struct ast_pair *found = moved_find(&state->moved, search); if (found) return found; return find_move(state->parent, def); } static struct ast_pair *find_reference(struct state *state, struct ast *def) { if (!state) return NULL; struct ast_pair search = {.def = def}; struct ast_pair *found = moved_find(&state->referenced, search); if (found) return found; return find_reference(state->parent, def); } static void insert_move(struct state *state, struct ast *def, struct ast *use) { struct ast_pair pair = {.def = def, .use = use}; moved_insert(&state->moved, pair); } static void insert_reference(struct state *state, struct ast *def, struct ast *use) { struct ast_pair pair = {.def = def, .use = use}; moved_insert(&state->referenced, pair); } static void merge_moves(struct state *to, struct state *from) { if (moved_len(&from->moved) == 0) return; foreach(moved, n, &from->moved) { moved_insert(&to->moved, *n); } } static void merge_queued(struct state *to, struct state *from) { if (moved_len(&from->queued) == 0) return; foreach(moved, n, &from->queued) { moved_insert(&to->queued, *n); } } static void merge_references(struct state *to, struct state *from) { if (moved_len(&from->referenced) == 0) return; foreach(moved, n, &from->referenced) { moved_insert(&to->referenced, *n); } } static void merge(struct state *to, struct state *from) { merge_moves(to, from); merge_queued(to, from); merge_references(to, from); } static void push_up(struct state *state) { struct state *parent = state->parent; if (!parent) return; merge(parent, state); } static void mark_queued(struct state *state) { if (moved_len(&state->moved) == 0) return; foreach(moved, m, &state->moved) { moved_insert(&state->queued, *m); } /* empty out moves now that they're all queued */ moved_destroy(&state->moved); state->moved = moved_create(); } static void mark_unqueued(struct state *state) { if (moved_len(&state->queued) == 0) return; foreach(moved, q, &state->queued) { moved_insert(&state->moved, *q); } moved_destroy(&state->queued); state->queued = moved_create(); } static void forget_references(struct state *state) { moved_destroy(&state->referenced); state->referenced = moved_create(); } static int mvcheck(struct state *state, struct ast *node); static int mvcheck_list(struct state *state, struct ast *nodes); static int mvcheck_statements(struct state *state, struct ast *nodes); static int refcheck_id(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, id_str(node)); assert(def); if (def->k != AST_VAR_DEF) { semantic_error(node->scope, node, "cannot reference functions"); return -1; } struct ast_pair *prev = find_move(state, def); if (prev) { move_error(node, prev->use); return -1; } prev = find_reference(state, def); if (prev) { reference_error(node, prev->use); return -1; } insert_reference(state, def, node); return 0; } static int refcheck(struct state *state, struct ast *node) { assert(is_lvalue(node)); switch (node->k) { case AST_ID: return refcheck_id(state, node); default: internal_error("unhandled node %s for refcheck", ast_str(node->k)); return -1; } return 0; } static int mvcheck_ref(struct state *state, struct ast *node) { return refcheck(state, ref_base(node)); } static int mvcheck_deref(struct state *state, struct ast *node) { /** @todo good enough for now but probably won't hold when we start * doing element lookups in structs etc. */ return mvcheck(state, deref_base(node)); } static int mvcheck_proc(struct state *state, struct ast *node) { /* extern, can't really do anything so just say it's fine */ if (!proc_body(node)) return 0; struct state new_state = create_state(state); /* we don't need to merge things into the parent state */ int ret = mvcheck(&new_state, proc_body(node)); destroy_state(&new_state); return ret; } static int total_check_single(struct state *state, struct scope *scope) { int ret = 0; foreach(visible, n, &scope->symbols) { struct ast *def = n->data; if (def->k != AST_VAR_DEF) continue; if (is_trivially_copyable(def->t)) continue; if (is_callable(def->t)) continue; struct ast_pair *prev = find_move(state, def); if (prev) continue; semantic_warn(scope, def, "%s not moved, might leak", var_id(def)); ret |= 1; } return ret; } static int total_check_proc(struct state *state, struct scope *scope) { int ret = total_check_single(state, scope); if (scope->flags & SCOPE_PROC) return ret; if (!scope->parent) return ret; return ret | total_check_proc(state, scope->parent); } static int mvcheck_block(struct state *state, struct ast *node) { struct state new_state = create_state(state); int ret = mvcheck_statements(&new_state, block_body(node)); if (ret) { destroy_state(&new_state); return ret; } if (block_error(node)) { if (total_check_proc(&new_state, node->scope)) semantic_info(node->scope, node, "at end of block"); /* mark us queued so moves are not visible for the next scopes, * but appear to the caller function if we're in a closure */ mark_queued(&new_state); push_up(&new_state); destroy_state(&new_state); return 0; } if (total_check_single(&new_state, node->scope)) semantic_info(node->scope, node, "at end of block"); push_up(&new_state); destroy_state(&new_state); return 0; } static int mvcheck_call(struct state *state, struct ast *node) { if (mvcheck(state, call_expr(node))) return -1; struct ast *args = call_args(node); if (!args) return 0; /* get all 'normal' parameters moved */ foreach_node(arg, call_args(node)) { if (arg->k == AST_CLOSURE) continue; if (mvcheck(state, arg)) return -1; } /* check into closures */ int ret = 0; struct state buffer_state = create_state(state); struct state group_state = create_state(&buffer_state); foreach_node(arg, call_args(node)) { if (arg->k != AST_CLOSURE) continue; struct state arg_state = create_state(state); ret = mvcheck(&arg_state, arg); mark_unqueued(&arg_state); merge(&group_state, &arg_state); destroy_state(&arg_state); struct ast *next = arg->n; if (!next || next->t->group != arg->t->group) { /* this group ended, push state immediately */ push_up(&group_state); /* something like reset_state would maybe be more clear? */ destroy_state(&group_state); group_state = create_state(state); } if (ret) break; } destroy_state(&group_state); /* the next section looks a bit weird, but it kind of makes sense. We * don't know when the error branch might get taken, so we first check * it 'before' the moves from any possible closures take place to see if * there's a possibility that some variable gets leaked. The second time * around we take the closure moves into account to see if the error * branch might accidentally move an already moved variable, i.e. the * user would need to add `own` blocks. */ /** @todo this is more or less what needs to happen for correctness, but * might mean that warnings from leaks in the first check get reprinted * in the second round, should probably add in some mechanism to check * against that */ if (call_err(node)) { struct state err_state = create_state(state); ret |= mvcheck(&err_state, call_err(node)); destroy_state(&err_state); } else if (total_check_proc(state, node->scope)) { semantic_info(node->scope, node, "in implicit err branch"); } push_up(&buffer_state); destroy_state(&buffer_state); if (call_err(node)) { struct state err_state = create_state(state); ret |= mvcheck(&err_state, call_err(node)); /* store results of this check */ push_up(&err_state); destroy_state(&err_state); } /* no need to check implicit error branch since it by definition can't * move already moved variables */ return ret; } static int mvcheck_closure(struct state *state, struct ast *node) { struct state new_state = create_state(state); new_state.pure = ast_flags(node, AST_FLAG_NOMOVES); int ret = mvcheck(&new_state, closure_body(node)); push_up(&new_state); destroy_state(&new_state); if (total_check_single(state, node->scope)) semantic_info(node->scope, node, "in closure"); return ret; } static void opt_group_left(struct state *state, struct ast *def, struct ast *node) { if (!def) return; insert_move(state, def, node); opt_group_left(state, def->ol, node); } static void opt_group_right(struct state *state, struct ast *def, struct ast *node) { if (!def) return; insert_move(state, def, node); opt_group_right(state, def->or, node); } static int mvcheck_id(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, id_str(node)); assert(def); /* moves only apply to variables, functions are 'eternal' */ if (def->k != AST_VAR_DEF) return 0; if (is_trivially_copyable(def->t)) return 0; if (in_pure(state)) { semantic_error(node->scope, node, "move in pure context not allowed"); return -1; } struct ast_pair *prev = find_move(state, def); if (prev) { /* error messages for opt groups could be improved */ move_error(node, prev->use); return -1; } prev = find_reference(state, def); if (prev) { reference_error(node, prev->use); return -1; } insert_move(state, def, node); if (def->or || def->ol) { /* we're part of an opt group, add all other members as if they * were moved as well */ opt_group_left(state, def->ol, node); opt_group_right(state, def->or, node); } return 0; } static int mvcheck_let(struct state *state, struct ast *node) { return mvcheck(state, let_expr(node)); } static int mvcheck_init(struct state *state, struct ast *node) { return mvcheck_list(state, init_body(node)); } static int mvcheck_if(struct state *state, struct ast *node) { if (mvcheck(state, if_cond(node))) return -1; struct state body_state = create_state(state); struct state else_state = create_state(state); if (mvcheck(&body_state, if_body(node))) { destroy_state(&body_state); destroy_state(&else_state); return -1; } if (mvcheck(&else_state, if_else(node))) { destroy_state(&body_state); destroy_state(&else_state); return -1; } push_up(&body_state); push_up(&else_state); destroy_state(&body_state); destroy_state(&else_state); return 0; } static int mvcheck_unop(struct state *state, struct ast *node) { return mvcheck(state, unop_expr(node)); } static int mvcheck_binop(struct state *state, struct ast *node) { if (mvcheck(state, binop_left(node))) return -1; return mvcheck(state, binop_right(node)); } static int mvcheck_comparison(struct state *state, struct ast *node) { if (mvcheck(state, comparison_left(node))) return -1; return mvcheck(state, comparison_right(node)); } static int mvcheck_list(struct state *state, struct ast *nodes) { foreach_node(node, nodes) { if (mvcheck(state, node)) return -1; } return 0; } static int mvcheck_statements(struct state *state, struct ast *nodes) { foreach_node(node, nodes) { struct state new_state = create_state(state); if (mvcheck(&new_state, node)) { destroy_state(&new_state); return -1; } forget_references(&new_state); push_up(&new_state); destroy_state(&new_state); } return 0; } static int mvcheck_err_branch(struct state *state, struct ast *node) { return mvcheck(state, err_branch_body(node)); } static int mvcheck_own(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, own_id(node)); assert(def); if (is_callable(def->t)) { semantic_error(node->scope, node, "ownership of callable cannot be checked"); return -1; } struct rm_move prev = remove_move(state, def); int ret = mvcheck(state, own_body(node)); /* insert back */ reinsert_move(prev); return ret; } static int mvcheck(struct state *state, struct ast *node) { if (is_unop(node)) return mvcheck_unop(state, node); if (is_binop(node)) return mvcheck_binop(state, node); if (is_comparison(node)) return mvcheck_comparison(state, node); switch (node->k) { case AST_PROC_DEF: return mvcheck_proc (state, node); case AST_BLOCK: return mvcheck_block (state, node); case AST_CALL: return mvcheck_call (state, node); case AST_CLOSURE: return mvcheck_closure (state, node); case AST_ERR_BRANCH: return mvcheck_err_branch(state, node); case AST_OWN: return mvcheck_own (state, node); case AST_LET: return mvcheck_let (state, node); case AST_ID: return mvcheck_id (state, node); case AST_INIT: return mvcheck_init (state, node); case AST_IF: return mvcheck_if (state, node); case AST_REF: return mvcheck_ref (state, node); case AST_DEREF: return mvcheck_deref (state, node); case AST_EMPTY: case AST_CONST_INT: case AST_CONST_STR: case AST_CONST_FLOAT: case AST_CONST_CHAR: return 0; default: break; } internal_error("missing move check for %s", ast_str(node->k)); return -1; } int mvcheck_root(struct ast *root) { foreach_node(node, root) { struct state state = create_state(NULL); if (mvcheck(&state, node)) return -1; destroy_state(&state); } return 0; }