From 598be4cd1bdd79e4859ae30291f4d65682cc672a Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Thu, 9 Jan 2025 22:26:02 +0200
Subject: initial reference checking

---
 src/move.c | 158 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 148 insertions(+), 10 deletions(-)

(limited to 'src/move.c')

diff --git a/src/move.c b/src/move.c
index ef9bf7f..9a89f75 100644
--- a/src/move.c
+++ b/src/move.c
@@ -12,20 +12,38 @@ struct ast_pair {
 
 struct state {
 	struct moved moved;
+	struct moved referenced;
 	struct state *parent;
+
+	/* pure means no moves allowed */
+	bool pure;
 };
 
+static bool in_pure(struct state *state)
+{
+	if (!state)
+		return false;
+
+	if (state->pure)
+		return true;
+
+	return in_pure(state->parent);
+}
+
 static struct state create_state(struct state *parent)
 {
 	struct state state = {};
 	state.parent = parent;
+	state.pure = false;
 	state.moved = moved_create();
+	state.referenced = moved_create();
 	return state;
 }
 
 static void destroy_state(struct state *state)
 {
 	moved_destroy(&state->moved);
+	moved_destroy(&state->referenced);
 }
 
 static struct ast_pair *find_move(struct state *state, struct ast *def)
@@ -41,13 +59,33 @@ static struct ast_pair *find_move(struct state *state, struct ast *def)
 	return find_move(state->parent, def);
 }
 
+static struct ast_pair *find_reference(struct state *state, struct ast *def)
+{
+	if (!state)
+		return NULL;
+
+	struct ast_pair search = {.def = def};
+	struct ast_pair *found = moved_find(&state->referenced, search);
+	if (found)
+		return found;
+
+	return find_reference(state->parent, def);
+}
+
 static void insert_move(struct state *state, struct ast *def, struct ast *use)
 {
-	struct ast_pair pair =  {.def = def, .use = use};
+	struct ast_pair pair = {.def = def, .use = use};
 	moved_insert(&state->moved, pair);
 }
 
-static void merge(struct state *to, struct state *from)
+static void insert_reference(struct state *state, struct ast *def,
+                             struct ast *use)
+{
+	struct ast_pair pair = {.def = def, .use = use};
+	moved_insert(&state->referenced, pair);
+}
+
+static void merge_moves(struct state *to, struct state *from)
 {
 	if (moved_len(&from->moved) == 0)
 		return;
@@ -57,6 +95,22 @@ static void merge(struct state *to, struct state *from)
 	}
 }
 
+static void merge_references(struct state *to, struct state *from)
+{
+	if (moved_len(&from->referenced) == 0)
+		return;
+
+	foreach(moved, n, &from->referenced) {
+		moved_insert(&to->referenced, *n);
+	}
+}
+
+static void merge(struct state *to, struct state *from)
+{
+	merge_moves(to, from);
+	merge_references(to, from);
+}
+
 static void push_up(struct state *state)
 {
 	struct state *parent = state->parent;
@@ -66,8 +120,60 @@ static void push_up(struct state *state)
 	merge(parent, state);
 }
 
+static void forget_references(struct state *state)
+{
+	moved_destroy(&state->referenced);
+}
+
 static int mvcheck(struct state *state, struct ast *node);
 static int mvcheck_list(struct state *state, struct ast *nodes);
+static int mvcheck_statements(struct state *state, struct ast *nodes);
+
+static int refcheck_id(struct state *state, struct ast *node)
+{
+	struct ast *def = file_scope_find_symbol(node->scope, id_str(node));
+	assert(def);
+
+	if (def->k != AST_VAR_DEF) {
+		semantic_error(node->scope, node, "cannot reference functions");
+		return -1;
+	}
+
+	struct ast_pair *prev = find_move(state, def);
+	if (prev) {
+		move_error(node, prev->use);
+		return -1;
+	}
+
+	prev = find_reference(state, def);
+	if (prev) {
+		reference_error(node, prev->use);
+		return -1;
+	}
+
+	insert_reference(state, def, node);
+	return 0;
+}
+
+static int refcheck(struct state *state, struct ast *node)
+{
+	assert(is_lvalue(node));
+
+	switch (node->k) {
+	case AST_ID: return refcheck_id(state, node);
+	default:
+		internal_error("unhandled node %s for refcheck",
+		               ast_str(node->k));
+		return -1;
+	}
+
+	return 0;
+}
+
+static int mvcheck_ref(struct state *state, struct ast *node)
+{
+	return refcheck(state, ref_base(node));
+}
 
 static int mvcheck_proc(struct state *state, struct ast *node)
 {
@@ -84,7 +190,7 @@ static int mvcheck_proc(struct state *state, struct ast *node)
 
 static int mvcheck_block(struct state *state, struct ast *node)
 {
-	return mvcheck_list(state, block_body(node));
+	return mvcheck_statements(state, block_body(node));
 }
 
 static int mvcheck_call(struct state *state, struct ast *node)
@@ -97,16 +203,18 @@ static int mvcheck_call(struct state *state, struct ast *node)
 		return 0;
 
 	int ret = 0;
-	long prev_group = args->t->group;
 	struct state group_state = create_state(state);
 
 	foreach_node(arg, call_args(node)) {
 		struct state arg_state = create_state(state);
 		ret = mvcheck(&arg_state, arg);
 
-		long group = arg->t->group;
-		if (prev_group != group) {
-			/* previous group ended, push states up */
+		merge(&group_state, &arg_state);
+		destroy_state(&arg_state);
+
+		struct ast *next = arg->n;
+		if (!next || next->t->group != arg->t->group) {
+			/* this group ended, push state immediately */
 			push_up(&group_state);
 
 			/* something like reset_state would maybe me more clear? */
@@ -114,9 +222,6 @@ static int mvcheck_call(struct state *state, struct ast *node)
 			group_state = create_state(state);
 		}
 
-		merge(&group_state, &arg_state);
-		destroy_state(&arg_state);
-		prev_group = group;
 		if (ret)
 			break;
 	}
@@ -128,6 +233,8 @@ static int mvcheck_call(struct state *state, struct ast *node)
 static int mvcheck_closure(struct state *state, struct ast *node)
 {
 	struct state new_state = create_state(state);
+	new_state.pure = ast_flags(node, AST_FLAG_NOMOVES);
+
 	int ret = mvcheck(&new_state, closure_body(node));
 	push_up(&new_state);
 
@@ -164,6 +271,12 @@ static int mvcheck_id(struct state *state, struct ast *node)
 	if (def->k != AST_VAR_DEF)
 		return 0;
 
+	if (in_pure(state)) {
+		semantic_error(node->scope, node,
+		               "move in pure context not allowed");
+		return -1;
+	}
+
 	struct ast_pair *prev = find_move(state, def);
 	if (prev) {
 		/* error messages for opt groups could be improved */
@@ -171,6 +284,12 @@ static int mvcheck_id(struct state *state, struct ast *node)
 		return -1;
 	}
 
+	prev = find_reference(state, def);
+	if (prev) {
+		reference_error(node, prev->use);
+		return -1;
+	}
+
 	insert_move(state, def, node);
 
 	if (def->or || def->ol) {
@@ -252,6 +371,23 @@ static int mvcheck_list(struct state *state, struct ast *nodes)
 	return 0;
 }
 
+static int mvcheck_statements(struct state *state, struct ast *nodes)
+{
+	foreach_node(node, nodes) {
+		struct state new_state = create_state(state);
+		if (mvcheck(&new_state, node)) {
+			destroy_state(&new_state);
+			return -1;
+		}
+
+		forget_references(&new_state);
+		push_up(&new_state);
+		destroy_state(&new_state);
+	}
+
+	return 0;
+}
+
 static int mvcheck(struct state *state, struct ast *node)
 {
 	if (is_unop(node))
@@ -272,6 +408,8 @@ static int mvcheck(struct state *state, struct ast *node)
 	case AST_ID:            return mvcheck_id       (state, node);
 	case AST_INIT:          return mvcheck_init     (state, node);
 	case AST_IF:            return mvcheck_if       (state, node);
+	case AST_REF:           return mvcheck_ref      (state, node);
+	case AST_EMPTY:
 	case AST_CONST_INT:
 	case AST_CONST_STR:
 	case AST_CONST_FLOAT:
-- 
cgit v1.2.3