#include <fwd/move.h>
#include <stdint.h>

struct ast_pair {
	struct ast *def;
	struct ast *use;
};

#define SPTREE_TYPE struct ast_pair
#define SPTREE_CMP(a, b) ((a).def - (b).def)
#define SPTREE_NAME moved
#include <fwd/sptree.h>

/* Move-tracking scope: the moves recorded at this level plus a link to the
 * enclosing scope (NULL for a root state). find_move() walks the parent
 * chain, so a move recorded in any ancestor is visible here. */
struct state {
	struct moved moved;
	struct state *parent;
};

static struct state create_state(struct state *parent)
{
	struct state state = {};
	state.parent = parent;
	state.moved = moved_create();
	return state;
}

/* Free the move set owned by a scope; the parent link is left untouched. */
static void destroy_state(struct state *state)
{
	struct moved *tree = &state->moved;
	moved_destroy(tree);
}

static struct ast_pair *find_move(struct state *state, struct ast *def)
{
	if (!state)
		return NULL;

	struct ast_pair search = {.def = def};
	struct ast_pair *found = moved_find(&state->moved, search);
	if (found)
		return found;

	return find_move(state->parent, def);
}

static void insert_move(struct state *state, struct ast *def, struct ast *use)
{
	struct ast_pair pair =  {.def = def, .use = use};
	moved_insert(&state->moved, pair);
}

static void push_up(struct state *state)
{
	struct state *parent = state->parent;
	if (!parent)
		return;

	if (moved_len(&state->moved) == 0)
		return;

	foreach(moved, n, &state->moved) {
		moved_insert(&parent->moved, *n);
	}
}

/* Forward declarations: the per-node checkers below call back into these,
 * which are defined after them. */
static int mvcheck(struct state *state, struct ast *node);
static int mvcheck_list(struct state *state, struct ast *nodes);

static int mvcheck_proc(struct state *state, struct ast *node)
{
	/* extern, can't really do anything so just say it's fine */
	if (!proc_body(node))
		return 0;

	struct state new_state = create_state(state);
	/* we don't need to merge things into the parent state */
	int ret = mvcheck(&new_state, proc_body(node));
	destroy_state(&new_state);
	return ret;
}

/* Check each statement of a block in the current scope (no new scope). */
static int mvcheck_block(struct state *state, struct ast *node)
{
	struct ast *stmts = block_body(node);
	return mvcheck_list(state, stmts);
}

static int mvcheck_list(struct state *state, struct ast *nodes)
{
	foreach_node(node, nodes) {
		if (mvcheck(state, node))
			return -1;
	}

	return 0;
}

/* Check the arguments of a call; the callee expression itself is not
 * visited here. */
static int mvcheck_call(struct state *state, struct ast *node)
{
	struct ast *args = call_args(node);
	return mvcheck_list(state, args);
}

static int mvcheck_closure(struct state *state, struct ast *node)
{
	struct state new_state = create_state(state);
	int ret = mvcheck(&new_state, closure_body(node));
	push_up(&new_state);

	destroy_state(&new_state);
	return ret;
}

/* Treat a use of an identifier as a move of its definition. Reports an
 * error if that definition was already moved in this scope or any
 * enclosing one; otherwise records the move. Returns 0 or -1. */
static int mvcheck_id(struct state *state, struct ast *node)
{
	struct ast *def = file_scope_find_symbol(node->scope, id_str(node));
	assert(def);

	struct ast_pair *earlier = find_move(state, def);
	if (!earlier) {
		insert_move(state, def, node);
		return 0;
	}

	move_error(node, earlier->use);
	return -1;
}

/* Dispatch a node to the move checker for its kind. Unhandled kinds are
 * reported through internal_error(). Returns 0 on success, nonzero on error. */
static int mvcheck(struct state *state, struct ast *node)
{
	switch (node->k) {
	case AST_PROC_DEF:
		return mvcheck_proc(state, node);
	case AST_BLOCK:
		return mvcheck_block(state, node);
	case AST_CALL:
		return mvcheck_call(state, node);
	case AST_CLOSURE:
		return mvcheck_closure(state, node);
	case AST_ID:
		return mvcheck_id(state, node);
	default:
		internal_error("missing move check for %s", ast_str(node->k));
		return -1;
	}
}

/* Run the move checker over every top-level node in root.
 * Each top-level node is checked in its own fresh root scope, which is
 * released once that node has been checked, whether or not it passed.
 * Returns 0 if all nodes pass, -1 at the first failing node (the error
 * itself is reported by the individual checks). */
int mvcheck_root(struct ast *root)
{
	foreach_node(node, root) {
		struct state state = create_state(NULL);
		int ret = mvcheck(&state, node);

		/* release the scope on both the success and the error path */
		destroy_state(&state);
		if (ret)
			return -1;
	}

	return 0;
}