#include struct ast_pair { struct ast *def; struct ast *use; }; #define SPTREE_TYPE struct ast_pair #define SPTREE_CMP(a, b) ((a).def - (b).def) #define SPTREE_NAME moved #include struct state { struct moved moved; struct moved queued; struct moved referenced; struct state *parent; /* pure means no moves allowed */ bool pure; }; static bool in_pure(struct state *state) { if (!state) return false; if (state->pure) return true; return in_pure(state->parent); } static struct state create_state(struct state *parent) { struct state state = {}; state.parent = parent; state.pure = false; state.queued = moved_create(); state.moved = moved_create(); state.referenced = moved_create(); return state; } static void destroy_state(struct state *state) { moved_destroy(&state->moved); } static struct ast_pair *find_move(struct state *state, struct ast *def) { if (!state) return NULL; struct ast_pair search = {.def = def}; struct ast_pair *found = moved_find(&state->moved, search); if (found) return found; return find_move(state->parent, def); } static struct ast_pair *find_reference(struct state *state, struct ast *def) { if (!state) return NULL; struct ast_pair search = {.def = def}; struct ast_pair *found = moved_find(&state->referenced, search); if (found) return found; return find_reference(state->parent, def); } static void insert_move(struct state *state, struct ast *def, struct ast *use) { struct ast_pair pair = {.def = def, .use = use}; moved_insert(&state->moved, pair); } static void insert_reference(struct state *state, struct ast *def, struct ast *use) { struct ast_pair pair = {.def = def, .use = use}; moved_insert(&state->referenced, pair); } static void merge_moves(struct state *to, struct state *from) { if (moved_len(&from->moved) == 0) return; foreach(moved, n, &from->moved) { moved_insert(&to->moved, *n); } } static void merge_queued(struct state *to, struct state *from) { if (moved_len(&from->queued) == 0) return; foreach(moved, n, &from->queued) { moved_insert(&to->queued, *n); } } static void merge_references(struct state *to, struct state *from) { if (moved_len(&from->referenced) == 0) return; foreach(moved, n, &from->referenced) { moved_insert(&to->referenced, *n); } } static void merge(struct state *to, struct state *from) { merge_moves(to, from); merge_queued(to, from); merge_references(to, from); } static void push_up(struct state *state) { struct state *parent = state->parent; if (!parent) return; merge(parent, state); } static void forget_references(struct state *state) { moved_destroy(&state->referenced); state->referenced = moved_create(); } static int mvcheck_expr(struct state *state, struct ast *node); static int mvcheck_expr_list(struct state *state, struct ast *nodes); static int mvcheck_statements(struct state *state, struct ast *nodes, bool last); static int refcheck_id(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, id_str(node)); assert(def); if (def->k != AST_VAR_DEF) { semantic_error(node->scope, node, "cannot reference functions"); return -1; } struct ast_pair *prev = find_move(state, def); if (prev) { move_error(node, prev->use); return -1; } prev = find_reference(state, def); if (prev) { reference_error(node, prev->use); return -1; } insert_reference(state, def, node); return 0; } static int refcheck(struct state *state, struct ast *node) { assert(is_lvalue(node)); switch (node->k) { case AST_ID: return refcheck_id(state, node); default: internal_error("unhandled node %s for refcheck", ast_str(node->k)); return -1; } return 0; } static int mvcheck_ref(struct state *state, struct ast *node) { return refcheck(state, ref_base(node)); } static int mvcheck_deref(struct state *state, struct ast *node) { /** @todo good enough for now but probably won't hold when we start * doing element lookups in structs etc. */ return mvcheck_expr(state, deref_base(node)); } static int total_check_single(struct state *state, struct scope *scope) { int ret = 0; foreach(visible, n, &scope->symbols) { struct ast *def = n->data; if (def->k != AST_VAR_DEF) continue; if (is_trivially_copyable(def->t)) continue; if (is_callable(def->t)) continue; struct ast_pair *prev = find_move(state, def); if (prev) continue; semantic_warn(scope, def, "%s not moved, might leak", var_id(def)); ret |= 1; } return ret; } static int total_check_proc(struct state *state, struct scope *scope) { int ret = total_check_single(state, scope); if (!scope->parent) return ret; if (scope->flags & SCOPE_PROC) return ret; return ret | total_check_proc(state, scope->parent); } static int mvcheck_block(struct state *state, struct ast *node, bool last) { assert(node->k == AST_BLOCK); struct state new_state = create_state(state); int ret = mvcheck_statements(&new_state, block_body(node), last); if (ret) { destroy_state(&new_state); return ret; } if (last && total_check_proc(&new_state, node->scope)) semantic_info(node->scope, node, "at end of block"); else if (!last && total_check_single(&new_state, node->scope)) semantic_info(node->scope, node, "at end of block"); /** @todo add exit analysis and run total_check_proc on those blocks */ push_up(&new_state); destroy_state(&new_state); return 0; } static int mvcheck_closure(struct state *state, struct ast *node, bool last) { struct state new_state = create_state(state); new_state.pure = ast_flags(node, AST_FLAG_NOMOVES); int ret = mvcheck_block(&new_state, closure_body(node), last); push_up(&new_state); destroy_state(&new_state); return ret; } static int mvcheck_call(struct state *state, struct ast *node, bool last) { if (mvcheck_expr(state, call_expr(node))) return -1; struct ast *args = call_args(node); if (!args) return 0; /* get all 'normal' parameters moved */ foreach_node(arg, call_args(node)) { if (arg->k == AST_CLOSURE) continue; if (mvcheck_expr(state, arg)) return -1; } /* count how many closure groups call has. If the call is an exit point * (last == true), then a single closure group must also be an exit * point */ size_t groups = 0; foreach_node(arg, call_args(node)) { if (arg->k != AST_CLOSURE) continue; groups++; struct ast *next = arg->n; while (next && next->t->group == arg->t->group) next = next->n; if (!next) break; arg = next; } if (!last && (groups > 0)) { semantic_error(node->scope, node, "calls with closures must currently be exit points, sorry!"); return -1; } /* check into closures */ int ret = 0; struct state buffer_state = create_state(state); struct state group_state = create_state(&buffer_state); foreach_node(arg, call_args(node)) { if (arg->k != AST_CLOSURE) continue; struct state arg_state = create_state(state); ret = mvcheck_closure(&arg_state, arg, last && (groups == 1)); merge(&group_state, &arg_state); destroy_state(&arg_state); struct ast *next = arg->n; if (!next || next->t->group != arg->t->group) { /* this group ended, push state immediately */ push_up(&group_state); /* something like reset_state would maybe be more clear? */ destroy_state(&group_state); group_state = create_state(state); } if (ret) break; } destroy_state(&group_state); push_up(&buffer_state); destroy_state(&buffer_state); return ret; } static void opt_group_left(struct state *state, struct ast *def, struct ast *node) { if (!def) return; insert_move(state, def, node); opt_group_left(state, def->ol, node); } static void opt_group_right(struct state *state, struct ast *def, struct ast *node) { if (!def) return; insert_move(state, def, node); opt_group_right(state, def->or, node); } static int mvcheck_id(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, id_str(node)); assert(def); /* moves only apply to variables, functions are 'eternal' */ if (def->k != AST_VAR_DEF) return 0; if (is_trivially_copyable(def->t)) return 0; if (in_pure(state)) { semantic_error(node->scope, node, "move in pure context not allowed"); return -1; } struct ast_pair *prev = find_move(state, def); if (prev) { /* error messages for opt groups could be improved */ move_error(node, prev->use); return -1; } prev = find_reference(state, def); if (prev) { reference_error(node, prev->use); return -1; } insert_move(state, def, node); if (def->or || def->ol) { /* we're part of an opt group, add all other members as if they * were moved as well */ opt_group_left(state, def->ol, node); opt_group_right(state, def->or, node); } return 0; } static int mvcheck_let(struct state *state, struct ast *node) { return mvcheck_expr(state, let_expr(node)); } static int mvcheck_init(struct state *state, struct ast *node) { return mvcheck_expr_list(state, init_body(node)); } static int mvcheck_if(struct state *state, struct ast *node, bool last) { if (!last) { semantic_error(node->scope, node, "`if` statements must currently be exit points, sorry!"); return -1; } assert(if_else(node)); /* don't check cond since it can't take ownership of anything */ struct state body_state = create_state(state); struct state else_state = create_state(state); if (mvcheck_block(&body_state, if_body(node), last)) { destroy_state(&body_state); destroy_state(&else_state); return -1; } if (mvcheck_block(&else_state, if_else(node), last)) { destroy_state(&body_state); destroy_state(&else_state); return -1; } push_up(&body_state); push_up(&else_state); destroy_state(&body_state); destroy_state(&else_state); return 0; } static int mvcheck_expr_list(struct state *state, struct ast *nodes) { foreach_node(node, nodes) { if (mvcheck_expr(state, node)) return -1; } return 0; } static int mvcheck_construct(struct state *state, struct ast *node) { foreach_node(expr, construct_members(node)) { if (mvcheck_expr(state, expr)) return -1; } return 0; } static int mvcheck_construction(struct state *state, struct ast *node) { return mvcheck_expr(state, construction_expr(node)); } static int mvcheck_as(struct state *state, struct ast *as) { return mvcheck_expr(state, as_expr(as)); } static int mvcheck_forget(struct state *state, struct ast *node) { struct ast *def = file_scope_find_symbol(node->scope, forget_id(node)); assert(def); /* act as if a move has happened */ insert_move(state, def, node); return 0; } static int mvcheck_explode(struct state *state, struct ast *node) { return mvcheck_expr(state, explode_expr(node)); } static int mvcheck_write(struct state *state, struct ast *node) { return mvcheck_expr(state, write_src(node)); } static int mvcheck_nil_check(struct state *state, struct ast *node, bool last) { if (!last) { /** @todo would this be an internal error? */ semantic_error(node->scope, node, "`nil check` must be exit point, sorry"); return -1; } assert(nil_check_rest(node)); struct state body_state = create_state(state); struct state rest_state = create_state(state); if (mvcheck_block(&body_state, nil_check_body(node), last)) { destroy_state(&body_state); destroy_state(&rest_state); return -1; } if (mvcheck_block(&rest_state, nil_check_rest(node), last)) { destroy_state(&body_state); destroy_state(&rest_state); return -1; } push_up(&body_state); push_up(&rest_state); destroy_state(&body_state); destroy_state(&rest_state); return 0; } static int mvcheck_expr(struct state *state, struct ast *node) { /* unary, binary and comparison operators (a kind of binary operator, * fair enough) must operate on primitive types, and as such don't need * to be move checked as none of them can change ownership or a value */ if (is_unop(node)) return 0; if (is_binop(node)) return 0; if (is_comparison(node)) return 0; switch (node->k) { case AST_ID: return mvcheck_id (state, node); case AST_INIT: return mvcheck_init (state, node); case AST_REF: return mvcheck_ref (state, node); case AST_DEREF: return mvcheck_deref (state, node); case AST_CONSTRUCT: return mvcheck_construct (state, node); case AST_AS: return mvcheck_as (state, node); case AST_FORGET: return mvcheck_forget (state, node); case AST_EXPLODE: return mvcheck_explode (state, node); case AST_WRITE: return mvcheck_write (state, node); case AST_CONSTRUCTION: return mvcheck_construction(state, node); case AST_SIZEOF: case AST_STRUCT_DEF: case AST_EMPTY: case AST_IMPORT: case AST_CONST_INT: case AST_CONST_STR: case AST_CONST_FLOAT: case AST_CONST_CHAR: case AST_NIL: return 0; default: break; } internal_error("missing move check for %s", ast_str(node->k)); return -1; } static int mvcheck_statement(struct state *state, struct ast *node, bool last) { switch (node->k) { case AST_CALL: return mvcheck_call (state, node, last); case AST_IF: return mvcheck_if (state, node, last); case AST_NIL_CHECK: return mvcheck_nil_check(state, node, last); case AST_EXPLODE: return mvcheck_explode (state, node); case AST_LET: return mvcheck_let (state, node); case AST_WRITE: return mvcheck_write (state, node); case AST_FORGET: return mvcheck_forget (state, node); case AST_EMPTY: return 0; default: break; } internal_error("unhandled move statement: %s", ast_str(node->k)); return -1; } static int mvcheck_statements(struct state *state, struct ast *nodes, bool last) { foreach_node(node, nodes) { struct state new_state = create_state(state); if (mvcheck_statement(&new_state, node, last && !node->n)) { destroy_state(&new_state); return -1; } forget_references(&new_state); push_up(&new_state); destroy_state(&new_state); } return 0; } static int mvcheck_proc(struct state *state, struct ast *node) { /* extern, can't really do anything so just say it's fine */ if (!proc_body(node)) return 0; struct state new_state = create_state(state); /* we don't need to merge things into the parent state */ int ret = mvcheck_block(&new_state, proc_body(node), true); destroy_state(&new_state); return ret; } static int mvcheck_top(struct state *state, struct ast *node) { switch (node->k) { case AST_PROC_DEF: return mvcheck_proc(state, node); case AST_IMPORT: return 0; case AST_STRUCT_DEF: return 0; default: break; } internal_error("missing top move check for %s", ast_str(node->k)); return -1; } int mvcheck_root(struct ast *root) { foreach_node(node, root) { struct state state = create_state(NULL); if (mvcheck_top(&state, node)) return -1; destroy_state(&state); } return 0; }