#include <fwd/move.h>

#include <stdint.h>

/* A (definition, use) pair stored in the `moved` trees. def is the variable
 * definition the pair is keyed on (see SPTREE_CMP). use is the AST node where
 * the variable was moved or referenced; it is handed to move_error() /
 * reference_error() when reporting a conflicting use. */
struct ast_pair {
	struct ast *def;
	struct ast *use;
};

/* Instantiate the conts sptree container for struct ast_pair under the name
 * `moved` (moved_create(), moved_find(), moved_insert(), ...), ordered by the
 * address of the definition node.
 *
 * The comparator is a three-way compare returning -1/0/1. It deliberately does
 * not subtract the pointers: subtracting pointers to unrelated objects is
 * undefined behavior, and the ptrdiff_t difference could be truncated (and
 * change sign) if the tree stores the result in an int. */
#define SPTREE_TYPE struct ast_pair
#define SPTREE_CMP(a, b) \
	(((uintptr_t)(a).def > (uintptr_t)(b).def) - \
	 ((uintptr_t)(a).def < (uintptr_t)(b).def))
#define SPTREE_NAME moved
#include <conts/sptree.h>

/* Move-checking state for one nested region of the AST. States form a chain
 * through parent: lookups (find_move(), find_reference(), in_pure()) walk up
 * the chain, and push_up() merges a finished child into its parent. */
struct state {
	/* moves visible to lookups through find_move() */
	struct moved moved;
	/* moves set aside by mark_queued() (used for blocks with an error) so
	 * they are not found by find_move(); mark_unqueued() turns them back into
	 * regular moves */
	struct moved queued;
	/* currently referenced variables; mvcheck_statements() drops them after
	 * each statement via forget_references() */
	struct moved referenced;
	/* enclosing state, NULL at the root */
	struct state *parent;

	/* pure means no moves allowed */
	bool pure;
};

/* Return true if state or any of its ancestors is marked pure, i.e. moves are
 * forbidden somewhere up the chain. A NULL state is never pure. */
static bool in_pure(struct state *state)
{
	for (struct state *s = state; s != NULL; s = s->parent) {
		if (s->pure)
			return true;
	}

	return false;
}

static struct state create_state(struct state *parent)
{
	struct state state = {};
	state.parent = parent;
	state.pure = false;
	state.queued = moved_create();
	state.moved = moved_create();
	state.referenced = moved_create();
	return state;
}

/* Release the three trees owned by state. The parent is not touched. */
static void destroy_state(struct state *state)
{
	struct moved *trees[] = {
		&state->moved,
		&state->queued,
		&state->referenced,
	};

	for (unsigned i = 0; i < sizeof trees / sizeof trees[0]; i++)
		moved_destroy(trees[i]);
}

/* Result of remove_move(): a copy of the removed pair and the state whose
 * moved tree it was taken from. owner is NULL when no move was found, in which
 * case reinsert_move() does nothing. */
struct rm_move {
	struct ast_pair data;
	struct state *owner;
};

static struct rm_move remove_move(struct state *state, struct ast *def)
{
	if (!state)
		return (struct rm_move){.data = {}, .owner = NULL};

	struct ast_pair search = {.def = def};
	struct ast_pair *found = moved_find(&state->moved, search);
	if (found) {
		moved_remove_found(&state->moved, found);
		struct rm_move r = {.data = *found, .owner = state};
		moved_free_found(&state->moved, found);
		return r;
	}

	return remove_move(state->parent, def);
}

static void reinsert_move(struct rm_move move)
{
	if (!move.owner)
		return;

	moved_insert(&move.owner->moved, move.data);
}

static struct ast_pair *find_move(struct state *state, struct ast *def)
{
	if (!state)
		return NULL;

	struct ast_pair search = {.def = def};
	struct ast_pair *found = moved_find(&state->moved, search);
	if (found)
		return found;

	return find_move(state->parent, def);
}

static struct ast_pair *find_reference(struct state *state, struct ast *def)
{
	if (!state)
		return NULL;

	struct ast_pair search = {.def = def};
	struct ast_pair *found = moved_find(&state->referenced, search);
	if (found)
		return found;

	return find_reference(state->parent, def);
}

static void insert_move(struct state *state, struct ast *def, struct ast *use)
{
	struct ast_pair pair = {.def = def, .use = use};
	moved_insert(&state->moved, pair);
}

static void insert_reference(struct state *state, struct ast *def,
                             struct ast *use)
{
	struct ast_pair pair = {.def = def, .use = use};
	moved_insert(&state->referenced, pair);
}

static void merge_moves(struct state *to, struct state *from)
{
	if (moved_len(&from->moved) == 0)
		return;

	foreach(moved, n, &from->moved) {
		moved_insert(&to->moved, *n);
	}
}

static void merge_queued(struct state *to, struct state *from)
{
	if (moved_len(&from->queued) == 0)
		return;

	foreach(moved, n, &from->queued) {
		moved_insert(&to->queued, *n);
	}
}

static void merge_references(struct state *to, struct state *from)
{
	if (moved_len(&from->referenced) == 0)
		return;

	foreach(moved, n, &from->referenced) {
		moved_insert(&to->referenced, *n);
	}
}

/* Merge all three trees (references, queued moves, moves) of from into the
 * matching trees of to. Each helper touches a different tree, so the order
 * does not matter. */
static void merge(struct state *to, struct state *from)
{
	merge_references(to, from);
	merge_queued(to, from);
	merge_moves(to, from);
}

/* Merge state into its parent. Does nothing for a root state. */
static void push_up(struct state *state)
{
	if (state->parent)
		merge(state->parent, state);
}

static void mark_queued(struct state *state)
{
	if (moved_len(&state->moved) == 0)
		return;

	foreach(moved, m, &state->moved) {
		moved_insert(&state->queued, *m);
	}

	/* empty out moves now that they're all queued */
	moved_destroy(&state->moved);
	state->moved = moved_create();
}

static void mark_unqueued(struct state *state)
{
	if (moved_len(&state->queued) == 0)
		return;

	foreach(moved, q, &state->queued) {
		moved_insert(&state->moved, *q);
	}

	moved_destroy(&state->queued);
	state->queued = moved_create();
}

static void forget_references(struct state *state)
{
	moved_destroy(&state->referenced);
	state->referenced = moved_create();
}

/* Forward declarations: the dispatcher and the list walkers are called by
 * handlers defined before them (e.g. mvcheck_deref(), mvcheck_block(),
 * mvcheck_init()). */
static int mvcheck(struct state *state, struct ast *node);
static int mvcheck_list(struct state *state, struct ast *nodes);
static int mvcheck_statements(struct state *state, struct ast *nodes);

/* Check that the variable named by the identifier node may be referenced. It
 * must be a variable, not already moved, and not already referenced. On
 * success the reference is recorded in state and 0 is returned; otherwise an
 * error is reported and -1 is returned. */
static int refcheck_id(struct state *state, struct ast *node)
{
	struct ast *def = file_scope_find_symbol(node->scope, id_str(node));
	assert(def);

	if (def->k != AST_VAR_DEF) {
		semantic_error(node->scope, node, "cannot reference functions");
		return -1;
	}

	struct ast_pair *moved = find_move(state, def);
	if (moved) {
		move_error(node, moved->use);
		return -1;
	}

	struct ast_pair *referenced = find_reference(state, def);
	if (referenced) {
		reference_error(node, referenced->use);
		return -1;
	}

	insert_reference(state, def, node);
	return 0;
}

/* Reference-check an lvalue. Only identifiers are supported so far; any other
 * kind is an internal error and yields -1. */
static int refcheck(struct state *state, struct ast *node)
{
	assert(is_lvalue(node));

	if (node->k == AST_ID)
		return refcheck_id(state, node);

	internal_error("unhandled node %s for refcheck",
	               ast_str(node->k));
	return -1;
}

/* Taking a reference checks its base as a reference, not as a move. */
static int mvcheck_ref(struct state *state, struct ast *node)
{
	struct ast *base = ref_base(node);
	return refcheck(state, base);
}

/* A dereference is checked like its base expression. */
static int mvcheck_deref(struct state *state, struct ast *node)
{
	/** @todo good enough for now but probably won't hold when we start
	 * doing element lookups in structs etc. */
	struct ast *base = deref_base(node);
	return mvcheck(state, base);
}

static int mvcheck_proc(struct state *state, struct ast *node)
{
	/* extern, can't really do anything so just say it's fine */
	if (!proc_body(node))
		return 0;

	struct state new_state = create_state(state);
	/* we don't need to merge things into the parent state */
	int ret = mvcheck(&new_state, proc_body(node));
	destroy_state(&new_state);
	return ret;
}

static int total_check_single(struct state *state, struct scope *scope)
{
	int ret = 0;
	foreach_visible(n, scope->symbols) {
		struct ast *def = n->node;
		if (def->k != AST_VAR_DEF)
			continue;

		if (is_trivially_copyable(def->t))
			continue;

		if (is_callable(def->t))
			continue;

		struct ast_pair *prev = find_move(state, def);
		if (prev)
			continue;

		semantic_warn(scope, def, "%s not moved, might leak", var_id(def));
		ret |= 1;
	}

	return ret;
}

/* Run total_check_single() on scope and each enclosing scope, innermost
 * first, stopping after the first scope flagged SCOPE_PROC or at the outermost
 * scope. Returns nonzero if any scope produced a warning. */
static int total_check_proc(struct state *state, struct scope *scope)
{
	int ret = 0;

	for (struct scope *s = scope; s; s = s->parent) {
		ret |= total_check_single(state, s);
		if (s->flags & SCOPE_PROC)
			break;
	}

	return ret;
}

static int mvcheck_block(struct state *state, struct ast *node)
{
	struct state new_state = create_state(state);
	int ret = mvcheck_statements(&new_state, block_body(node));
	if (ret) {
		destroy_state(&new_state);
		return ret;
	}

	if (block_error(node)) {
		if (total_check_proc(&new_state, node->scope))
			semantic_info(node->scope, node, "at end of block");

		/* mark us queued so moves are not visible for the next scopes,
		 * but appear to the caller function if we're in a closure */
		mark_queued(&new_state);
		push_up(&new_state);
		destroy_state(&new_state);
		return 0;
	}

	if (total_check_single(&new_state, node->scope))
		semantic_info(node->scope, node, "at end of block");

	push_up(&new_state);
	destroy_state(&new_state);
	return 0;
}

static int mvcheck_call(struct state *state, struct ast *node)
{
	if (mvcheck(state, call_expr(node)))
		return -1;

	struct ast *args = call_args(node);
	if (!args)
		return 0;

	/* get all 'normal' parameters moved */
	foreach_node(arg, call_args(node)) {
		if (arg->k == AST_CLOSURE)
			continue;

		if (mvcheck(state, arg))
			return -1;
	}

	/* check into closures */
	int ret = 0;
	struct state buffer_state = create_state(state);
	struct state group_state = create_state(&buffer_state);
	foreach_node(arg, call_args(node)) {
		if (arg->k != AST_CLOSURE)
			continue;

		struct state arg_state = create_state(state);
		ret = mvcheck(&arg_state, arg);
		mark_unqueued(&arg_state);

		merge(&group_state, &arg_state);
		destroy_state(&arg_state);

		struct ast *next = arg->n;
		if (!next || next->t->group != arg->t->group) {
			/* this group ended, push state immediately */
			push_up(&group_state);

			/* something like reset_state would maybe be more clear? */
			destroy_state(&group_state);
			group_state = create_state(state);
		}

		if (ret)
			break;
	}

	destroy_state(&group_state);

	/* the next section looks a bit weird, but it kind of makes sense. We
	 * don't know when the error branch might get taken, so we first check
	 * it 'before' the moves from any possible closures take place to see if
	 * there's a possibility that some variable gets leaked. The second time
	 * around we take the closure moves into account to see if the error
	 * branch might accidentally move an already moved variable, i.e. the
	 * user would need to add `own` blocks. */

	/** @todo this is more or less what needs to happen for correctness, but
	 * might mean that warnings from leaks in the first check get reprinted
	 * in the second round, should probably add in some mechanism to check
	 * against that */
	if (call_err(node)) {
		struct state err_state = create_state(state);
		ret |= mvcheck(&err_state, call_err(node));
		destroy_state(&err_state);
	}
	else if (total_check_proc(state, node->scope)) {
		semantic_info(node->scope, node, "in implicit err branch");
	}

	push_up(&buffer_state);
	destroy_state(&buffer_state);

	if (call_err(node)) {
		struct state err_state = create_state(state);
		ret |= mvcheck(&err_state, call_err(node));

		/* store results of this check */
		push_up(&err_state);
		destroy_state(&err_state);
	}
	/* no need to check implicit error branch since it by definition can't
	 * move already moved variables */

	return ret;
}

static int mvcheck_closure(struct state *state, struct ast *node)
{
	struct state new_state = create_state(state);
	new_state.pure = ast_flags(node, AST_FLAG_NOMOVES);

	int ret = mvcheck(&new_state, closure_body(node));
	push_up(&new_state);
	destroy_state(&new_state);

	if (total_check_single(state, node->scope))
		semantic_info(node->scope, node, "in closure");

	return ret;
}

/* Record a move at node for def and every member reachable by following
 * ->ol links. A NULL def records nothing. */
static void opt_group_left(struct state *state, struct ast *def,
                           struct ast *node)
{
	for (struct ast *member = def; member; member = member->ol)
		insert_move(state, member, node);
}

/* Record a move at node for def and every member reachable by following
 * ->or links. A NULL def records nothing. */
static void opt_group_right(struct state *state, struct ast *def,
                            struct ast *node)
{
	for (struct ast *member = def; member; member = member->or)
		insert_move(state, member, node);
}

/* Treat a use of an identifier as a move of the variable it names.
 *
 * Non-variables and trivially copyable variables are exempt. Otherwise the
 * move fails (-1, with a diagnostic) in a pure context, if the variable was
 * already moved, or if it is currently referenced. On success the move is
 * recorded, for every other member of the variable's opt group as well, and
 * 0 is returned. */
static int mvcheck_id(struct state *state, struct ast *node)
{
	struct ast *def = file_scope_find_symbol(node->scope, id_str(node));
	assert(def);

	/* moves only apply to variables, functions are 'eternal' */
	if (def->k != AST_VAR_DEF || is_trivially_copyable(def->t))
		return 0;

	if (in_pure(state)) {
		semantic_error(node->scope, node,
		               "move in pure context not allowed");
		return -1;
	}

	struct ast_pair *moved = find_move(state, def);
	if (moved) {
		/* error messages for opt groups could be improved */
		move_error(node, moved->use);
		return -1;
	}

	struct ast_pair *referenced = find_reference(state, def);
	if (referenced) {
		reference_error(node, referenced->use);
		return -1;
	}

	insert_move(state, def, node);

	/* if we're part of an opt group, add all other members as if they were
	 * moved as well; both helpers do nothing for a NULL neighbour */
	opt_group_left(state, def->ol, node);
	opt_group_right(state, def->or, node);
	return 0;
}

/* A let is checked through its expression. */
static int mvcheck_let(struct state *state, struct ast *node)
{
	struct ast *expr = let_expr(node);
	return mvcheck(state, expr);
}

/* An initializer is checked element by element, in the same state. */
static int mvcheck_init(struct state *state, struct ast *node)
{
	struct ast *elems = init_body(node);
	return mvcheck_list(state, elems);
}

/* Check an if: the condition runs in state, and each branch runs in its own
 * sibling state so the branches don't see each other's moves. On success the
 * moves of both branches are pushed into state. Returns 0 on success, -1 on
 * the first failing part. */
static int mvcheck_if(struct state *state, struct ast *node)
{
	if (mvcheck(state, if_cond(node)))
		return -1;

	int ret = -1;
	struct state body_state = create_state(state);
	struct state else_state = create_state(state);

	if (mvcheck(&body_state, if_body(node)))
		goto out;

	/* mvcheck() dereferences its node, so guard against a missing else the
	 * same way mvcheck_proc()/mvcheck_call() guard optional children.
	 * NOTE(review): presumably if_else() is NULL for an if without else;
	 * confirm against the parser */
	if (if_else(node) && mvcheck(&else_state, if_else(node)))
		goto out;

	push_up(&body_state);
	push_up(&else_state);
	ret = 0;

out:
	destroy_state(&body_state);
	destroy_state(&else_state);
	return ret;
}

/* A unary operation is checked through its operand. */
static int mvcheck_unop(struct state *state, struct ast *node)
{
	struct ast *operand = unop_expr(node);
	return mvcheck(state, operand);
}

/* Check the left operand, then the right one. Stops with -1 if the left
 * fails. */
static int mvcheck_binop(struct state *state, struct ast *node)
{
	if (mvcheck(state, binop_left(node)) == 0)
		return mvcheck(state, binop_right(node));

	return -1;
}

/* Check the left side of a comparison, then the right one. Stops with -1 if
 * the left fails. */
static int mvcheck_comparison(struct state *state, struct ast *node)
{
	if (mvcheck(state, comparison_left(node)) == 0)
		return mvcheck(state, comparison_right(node));

	return -1;
}

static int mvcheck_list(struct state *state, struct ast *nodes)
{
	foreach_node(node, nodes) {
		if (mvcheck(state, node))
			return -1;
	}

	return 0;
}

static int mvcheck_statements(struct state *state, struct ast *nodes)
{
	foreach_node(node, nodes) {
		struct state new_state = create_state(state);
		if (mvcheck(&new_state, node)) {
			destroy_state(&new_state);
			return -1;
		}

		forget_references(&new_state);
		push_up(&new_state);
		destroy_state(&new_state);
	}

	return 0;
}

/* An error branch is checked through its body, in the same state. */
static int mvcheck_err_branch(struct state *state, struct ast *node)
{
	struct ast *body = err_branch_body(node);
	return mvcheck(state, body);
}

static int mvcheck_own(struct state *state, struct ast *node)
{
	struct ast *def = file_scope_find_symbol(node->scope, own_id(node));
	assert(def);

	if (is_callable(def->t)) {
		semantic_error(node->scope, node,
				"ownership of callable cannot be checked");
		return -1;
	}

	struct rm_move prev = remove_move(state, def);
	int ret = mvcheck(state, own_body(node));

	/* insert back */
	reinsert_move(prev);
	return ret;
}

static int mvcheck(struct state *state, struct ast *node)
{
	if (is_unop(node))
		return mvcheck_unop(state, node);

	if (is_binop(node))
		return mvcheck_binop(state, node);

	if (is_comparison(node))
		return mvcheck_comparison(state, node);

	switch (node->k) {
	case AST_PROC_DEF:      return mvcheck_proc      (state, node);
	case AST_BLOCK:         return mvcheck_block     (state, node);
	case AST_CALL:          return mvcheck_call      (state, node);
	case AST_CLOSURE:       return mvcheck_closure   (state, node);
	case AST_ERR_BRANCH:	return mvcheck_err_branch(state, node);
	case AST_OWN:		return mvcheck_own	 (state, node);
	case AST_LET:           return mvcheck_let       (state, node);
	case AST_ID:            return mvcheck_id        (state, node);
	case AST_INIT:          return mvcheck_init      (state, node);
	case AST_IF:            return mvcheck_if        (state, node);
	case AST_REF:           return mvcheck_ref       (state, node);
	case AST_DEREF:         return mvcheck_deref     (state, node);
	case AST_EMPTY:
	case AST_CONST_INT:
	case AST_CONST_STR:
	case AST_CONST_FLOAT:
	case AST_CONST_CHAR:
		return 0;
	default: break;
	}

	internal_error("missing move check for %s", ast_str(node->k));
	return -1;
}

/* Entry point: move-check every top-level node in the list starting at root,
 * each with a fresh root state. Returns -1 as soon as one node fails, 0 if
 * all pass. The per-node state is now released on the failure path too (it
 * used to leak on early return). */
int mvcheck_root(struct ast *root)
{
	foreach_node(node, root) {
		struct state state = create_state(NULL);
		int ret = mvcheck(&state, node);
		destroy_state(&state);

		if (ret)
			return -1;
	}

	return 0;
}