#include <posthaste/check.h>
#include <posthaste/debug.h>
#include <posthaste/utils.h>

/* Context threaded through the semantic checker. The only thing that turned
 * out to be needed is the func/proc definition currently being checked.
 *
 * parent: the enclosing AST_FUNC_DEF/AST_PROC_DEF, or NULL at file scope
 *         (check() zero-initializes the state). analyze_return() uses it to
 *         reject stray returns and to match/deduce the return type. */
struct state {
	struct ast *parent;
};

/* First checking phase: register a top-level func/proc definition @n in
 * @scope so it can be referenced before it textually appears.
 *
 * Returns 0 on success (including for nodes that are not definitions) and -1
 * if scope_add_func()/scope_add_proc() rejects the node (presumably because
 * the name is already defined -- see the comment on check()). */
int analyze_visibility(struct scope *scope, struct ast *n)
{
	switch (n->k) {
	case AST_FUNC_DEF:
		if (scope_add_func(scope, n))
			return -1;
		break;

	case AST_PROC_DEF:
		if (scope_add_proc(scope, n))
			return -1;
		break;

	default:
		/* only funcs/procs are made visible ahead of time; everything
		 * else is handled by the second phase. A label must be followed
		 * by a statement before C23, hence the explicit break. */
		break;
	}

	return 0;
}

/* True once a type has been assigned to @n; a zero type means "untyped". */
static bool has_type(struct ast *n)
{
	if (n->t == 0)
		return false;

	return true;
}

/* A concrete type is one a value can actually have: anything except void
 * (no value) and auto (not deduced yet). */
static bool concrete_type(enum type_kind k)
{
	return k != TYPE_VOID && k != TYPE_AUTO;
}

static int analyze(struct state *state, struct scope *scope, struct ast *n);
static int analyze_list(struct state *state, struct scope *scope, struct ast *l)
{
	foreach_node(n, l) {
		if (analyze(state, scope, n))
			return -1;
	}

	return 0;
}

static int analyze_statement_list(struct state *state, struct scope *scope, struct ast *l)
{
	foreach_node(n, l) {
		if (analyze(state, scope, n))
			return -1;

		if (n->k != AST_PROC_CALL)
			continue;

		if (n->t != TYPE_VOID) {
			semantic_error(scope, n,
					"non-void proc call not allowed as statement");
			return -1;
		}
	}

	return 0;
}

/* Look up definition @id (func/proc/variable) in @scope and its parent
 * scopes, up to and including the file scope, and make sure whatever is found
 * has been fully analyzed. Returns the definition, or NULL if it does not
 * exist or its analysis failed (the two cases are not distinguished). */
static struct ast *file_scope_find_analyzed(struct state *state,
                                            struct scope *scope, char *id)
{
	struct ast *def = file_scope_find(scope, id);
	if (def && analyze(state, scope, def) == 0)
		return def;

	return NULL;
}

static int analyze_type(struct scope *scope, struct ast *n)
{
	if (!n->type) {
		n->t = TYPE_VOID;
		return 0;
	}

	if (same_id(n->type, "int")) {
		n->t = TYPE_INT;
		return 0;
	}
	else if (same_id(n->type, "date")) {
		n->t = TYPE_DATE;
		return 0;
	}
	else if (same_id(n->type, "auto")) {
		n->t = TYPE_AUTO;
		return 0;
	}

	semantic_error(scope, n, "no such type: %s", n->type);
	return -1;
}

static int analyze_func_def(struct state *state, struct scope *scope,
                            struct ast *f)
{
	UNUSED(state);
	if (analyze_type(scope, f))
		return -1;

	/* slightly hacky, but easiest way to allow recursion */
	ast_set_flags(f, AST_FLAG_CHECKED);

	struct scope *func_scope = create_scope();
	scope_add_scope(scope, func_scope);
	f->scope = func_scope;

	struct state func_state = {.parent = f};
	if (analyze_list(&func_state, func_scope, func_formals(f)))
		return -1;

	if (analyze_list(&func_state, func_scope, func_vars(f)))
		return -1;

	if (analyze(&func_state, func_scope, func_body(f)))
		return -1;

	if (f->t == TYPE_AUTO) {
		semantic_error(scope, f, "unable to determine return type");
		return -1;
	}

	return 0;
}

static int analyze_proc_def(struct state *state, struct scope *scope,
                            struct ast *p)
{
	UNUSED(state);
	if (analyze_type(scope, p))
		return -1;

	ast_set_flags(p, AST_FLAG_CHECKED);

	struct scope *proc_scope = create_scope();
	scope_add_scope(scope, proc_scope);
	p->scope = proc_scope;

	struct state proc_state = {.parent = p};
	if (analyze_list(&proc_state, proc_scope, proc_formals(p)))
		return -1;

	if (analyze_list(&proc_state, proc_scope, proc_vars(p)))
		return -1;

	if (analyze_statement_list(&proc_state, proc_scope, proc_body(p)))
		return -1;

	struct ast *last = ast_last(proc_body(p));
	/* here we could for example implement checking that all branches in an unless
	 * return, but for simplicity require user to add explicit return as
	 * last statement in body */
	if (p->t != TYPE_VOID && last->k != AST_RETURN) {
		semantic_error(scope, p, "can't prove that proc reaches return statement");
		return -1;
	}

	if (p->t == TYPE_AUTO) {
		semantic_error(scope, p, "unable to determine return type");
		return -1;
	}

	return 0;
}

/* Check a variable definition @n. The variable is added to @scope before its
 * initializer is analyzed, and takes the initializer's type, which must be
 * concrete (not void/auto). */
static int analyze_var_def(struct state *state, struct scope *scope,
                           struct ast *n)
{
	if (scope_add_var(scope, n))
		return -1;

	struct ast *init = var_expr(n);
	if (analyze(state, scope, init))
		return -1;

	if (concrete_type(init->t)) {
		n->t = init->t;
		return 0;
	}

	semantic_error(scope, n, "illegal type in variable definition: %s",
	               type_str(init->t));
	return -1;
}

/* Check a formal parameter @n: its written type must resolve to a concrete
 * type (not void/auto) before it is added to the callee's @scope. */
static int analyze_formal_def(struct state *state, struct scope *scope,
                              struct ast *n)
{
	UNUSED(state);
	if (analyze_type(scope, n))
		return -1;

	if (concrete_type(n->t))
		return scope_add_formal(scope, n) ? -1 : 0;

	semantic_error(scope, n, "illegal type for formal parameter: %s",
	               type_str(n->t));
	return -1;
}

/* Check unary negation @n: the operand must be an int and so is the result.
 * The diagnostic has no trailing "\n", matching every other semantic_error()
 * message in this file. */
static int analyze_neg(struct state *state, struct scope *scope, struct ast *n)
{
	struct ast *expr = neg_expr(n);
	if (analyze(state, scope, expr))
		return -1;

	if (expr->t != TYPE_INT) {
		semantic_error(scope, n,
		               "negation requires %s, got %s",
		               type_str(TYPE_INT),
		               type_str(expr->t));
		return -1;
	}

	n->t = expr->t;
	return 0;
}

/* Check unary plus @p: the operand must be an int and so is the result.
 * The diagnostic has no trailing "\n", matching every other semantic_error()
 * message in this file. */
static int analyze_pos(struct state *state, struct scope *scope, struct ast *p)
{
	struct ast *expr = pos_expr(p);
	if (analyze(state, scope, expr))
		return -1;

	if (expr->t != TYPE_INT) {
		semantic_error(scope, p,
		               "unary pos requires %s, got %s",
		               type_str(TYPE_INT),
		               type_str(expr->t));
		return -1;
	}

	p->t = expr->t;
	return 0;
}

/* A date literal is always of type date; its format was already validated by
 * the lexer, so there is nothing else to check. */
static int analyze_const_date(struct state *state, struct scope *scope,
                              struct ast *d)
{
	(void)state;
	(void)scope;

	d->t = TYPE_DATE;
	return 0;
}

/* An integer literal is always of type int. */
static int analyze_const_int(struct state *state, struct scope *scope,
                             struct ast *i)
{
	(void)state;
	(void)scope;

	i->t = TYPE_INT;
	return 0;
}

/* A string literal is always of type string. */
static int analyze_const_string(struct state *state, struct scope *scope,
                                struct ast *s)
{
	(void)state;
	(void)scope;

	s->t = TYPE_STRING;
	return 0;
}

/* Check equality @e: both operands must have the same type; the result is a
 * bool. The diagnostic has no trailing "\n", matching every other
 * semantic_error() message in this file. */
static int analyze_eq(struct state *state, struct scope *scope, struct ast *e)
{
	struct ast *l = eq_l(e);
	if (analyze(state, scope, l))
		return -1;

	struct ast *r = eq_r(e);
	if (analyze(state, scope, r))
		return -1;

	if (r->t != l->t) {
		semantic_error(scope, e, "type mismatch: %s vs %s",
		               type_str(l->t),
		               type_str(r->t));
		return -1;
	}

	e->t = TYPE_BOOL;
	return 0;
}

/* Check less-than @e: both operands must have the same type; the result is a
 * bool. The diagnostic has no trailing "\n", matching every other
 * semantic_error() message in this file. */
static int analyze_lt(struct state *state, struct scope *scope, struct ast *e)
{
	struct ast *l = lt_l(e);
	if (analyze(state, scope, l))
		return -1;

	struct ast *r = lt_r(e);
	if (analyze(state, scope, r))
		return -1;

	if (r->t != l->t) {
		semantic_error(scope, e, "type mismatch: %s vs %s",
		               type_str(l->t),
		               type_str(r->t));
		return -1;
	}

	e->t = TYPE_BOOL;
	return 0;
}

static int analyze_add(struct state *state, struct scope *scope, struct ast *a)
{
	struct ast *l = add_l(a);
	if (analyze(state, scope, l))
		return -1;

	struct ast *r = add_r(a);
	if (analyze(state, scope, r))
		return -1;

	if (r->t == TYPE_DATE) {
		semantic_error(scope, r, "date not allowed as rhs to addition");
		return -1;
	}

	if (l->t == TYPE_DATE) {
		a->k = AST_DATE_ADD;
	}

	a->t = l->t;
	return 0;
}

/* Check subtraction @s. Allowed combinations:
 *   date - date -> int  (node becomes AST_DATE_DIFF)
 *   date - int  -> date (node becomes AST_DATE_SUB)
 *   int  - int  -> int  (stays a plain subtraction)
 * Anything else is an error. */
static int analyze_sub(struct state *state, struct scope *scope, struct ast *s)
{
	struct ast *lhs = sub_l(s);
	if (analyze(state, scope, lhs))
		return -1;

	struct ast *rhs = sub_r(s);
	if (analyze(state, scope, rhs))
		return -1;

	if (lhs->t == TYPE_DATE) {
		if (rhs->t == TYPE_DATE) {
			s->k = AST_DATE_DIFF;
			s->t = TYPE_INT;
			return 0;
		}

		if (rhs->t == TYPE_INT) {
			s->k = AST_DATE_SUB;
			s->t = TYPE_DATE;
			return 0;
		}
	} else if (lhs->t == TYPE_INT && rhs->t == TYPE_INT) {
		s->t = TYPE_INT;
		return 0;
	}

	semantic_error(scope, s, "illegal subtraction types: %s, %s",
	               type_str(lhs->t), type_str(rhs->t));
	return -1;
}

/* Check multiplication @m: both operands must be int, and so is the result.
 * The left operand is analyzed and validated before the right one is
 * fetched, so errors are reported left to right. */
static int analyze_mul(struct state *state, struct scope *scope, struct ast *m)
{
	for (int side = 0; side < 2; side++) {
		struct ast *operand = side == 0 ? mul_l(m) : mul_r(m);
		if (analyze(state, scope, operand))
			return -1;

		if (operand->t != TYPE_INT) {
			semantic_error(scope, operand, "expected %s, got %s",
			               type_str(TYPE_INT), type_str(operand->t));
			return -1;
		}
	}

	m->t = TYPE_INT;
	return 0;
}

/* Check division @d: both operands must be int, and so is the result.
 * The left operand is analyzed and validated before the right one is
 * fetched, so errors are reported left to right. */
static int analyze_div(struct state *state, struct scope *scope, struct ast *d)
{
	for (int side = 0; side < 2; side++) {
		struct ast *operand = side == 0 ? div_l(d) : div_r(d);
		if (analyze(state, scope, operand))
			return -1;

		if (operand->t != TYPE_INT) {
			semantic_error(scope, operand, "expected %s, got %s",
			               type_str(TYPE_INT), type_str(operand->t));
			return -1;
		}
	}

	d->t = TYPE_INT;
	return 0;
}

/* Check a print statement @p: every item is analyzed and must not be void.
 * While walking the items, a replacement list is built from the
 * fixup_print_item() results and stored back into print_items(p).
 * NOTE(review): whether fixup_print_item() returns @n itself or a fresh node
 * is not visible here -- confirm before touching the list rewiring below. */
static int analyze_print(struct state *state, struct scope *scope,
                         struct ast *p)
{
	struct ast *items = print_items(p);
	/* head and tail of the rebuilt item list */
	struct ast *fixups = NULL, *last_fixup = NULL;

	foreach_node(n, items) {
		if (analyze(state, scope, n))
			return -1;

		if (n->t == TYPE_VOID) {
			semantic_error(scope, n, "void not printable");
			return -1;
		}

		/* build new list of nodes, helps a bit when lowering */
		struct ast *fixup = fixup_print_item(n);
		if (!fixups)
			fixups = fixup;

		/* append to the tail of the rebuilt list */
		if (last_fixup) {
			last_fixup->n = fixup;
		}

		last_fixup = fixup;
	}

	/* print_items() is used as an lvalue here, so it presumably expands to
	 * a field access -- TODO confirm */
	print_items(p) = fixups;
	p->t = TYPE_VOID;
	return 0;
}

static int analyze_return(struct state *state, struct scope *scope,
                          struct ast *r)
{
	struct ast *parent = state->parent;
	if (!parent) {
		semantic_error(scope, r, "stray return");
		return -1;
	}


	if (!has_type(parent)) {
		semantic_error(scope, r,
		               "stray return in proc without return type");
		return -1;
	}

	struct ast *expr = return_expr(r);
	if (analyze(state, scope, expr))
		return -1;

	/* deduce return type from first return we hit */
	if (parent->t == TYPE_AUTO)
		parent->t = expr->t;

	if (expr->t != parent->t) {
		semantic_error(scope, r, "return type mismatch: %s vs %s",
		               type_str(parent->t), type_str(expr->t));
		return -1;
	}

	r->t = expr->t;
	return 0;
}

/* Check a variable reference @i. The name must resolve, in this or an
 * enclosing scope up to the file scope, to a formal parameter or a variable
 * definition, and the reference takes that definition's type. The definition
 * is analyzed first if that has not happened yet.
 *
 * Lookup and analysis are done separately so that a definition which fails
 * to analyze (e.g. a variable whose initializer refers to itself, reported as
 * a semantic loop) is not followed by a bogus "no such symbol" error. */
static int analyze_id(struct state *state, struct scope *scope, struct ast *i)
{
	struct ast *exists = file_scope_find(scope, i->id);
	if (!exists) {
		semantic_error(scope, i, "no such symbol");
		return -1;
	}

	/* any failure here has already been reported by analyze() */
	if (analyze(state, scope, exists))
		return -1;

	if (exists->k != AST_FORMAL_DEF && exists->k != AST_VAR_DEF) {
		semantic_error(scope, i, "no such variable");
		return -1;
	}

	i->t = exists->t;
	return 0;
}

/* Check a date attribute node @d (its error message calls it a write
 * attribute): the base must be a date, only day/month/year are accepted, and
 * the node has type int. */
static int analyze_dot(struct state *state, struct scope *scope, struct ast *d)
{
	struct ast *base = dot_base(d);
	if (analyze(state, scope, base))
		return -1;

	if (base->t != TYPE_DATE) {
		semantic_error(scope, d, "expected %s, got %s",
		               type_str(TYPE_DATE), type_str(base->t));
		return -1;
	}

	/* the type is assigned even when the attribute name turns out invalid */
	d->t = TYPE_INT;

	bool known = same_id(d->id, "day")
	          || same_id(d->id, "month")
	          || same_id(d->id, "year");
	if (known)
		return 0;

	semantic_error(scope, d, "illegal write attribute %s", d->id);
	return -1;
}

/* Check a date attribute access @d: the base must be a date and the
 * attribute one of day, month, year, weekday or weeknum; the access has type
 * int. The error message uses neutral wording: this handler also accepts
 * derived attributes (weekday/weeknum), so calling them "write" attributes,
 * as the original message did, was misleading. */
static int analyze_attr(struct state *state, struct scope *scope, struct ast *d)
{
	struct ast *base = attr_base(d);
	if (analyze(state, scope, base))
		return -1;

	if (base->t != TYPE_DATE) {
		semantic_error(scope, d, "expected %s, got %s",
		               type_str(TYPE_DATE), type_str(base->t));
		return -1;
	}

	d->t = TYPE_INT;
	if (same_id(d->id, "day"))
		return 0;

	if (same_id(d->id, "month"))
		return 0;

	if (same_id(d->id, "year"))
		return 0;

	if (same_id(d->id, "weekday"))
		return 0;

	if (same_id(d->id, "weeknum"))
		return 0;

	semantic_error(scope, d, "illegal attribute %s", d->id);
	return -1;
}

/* Check an assignment @n. Both sides must have concrete (non-void, non-auto)
 * types and those types must agree; the assignment takes the type of its
 * left-hand side. The left side is fully validated before the right side is
 * analyzed. */
static int analyze_assign(struct state *state, struct scope *scope,
                          struct ast *n)
{
	struct ast *dst = assign_l(n);
	if (analyze(state, scope, dst))
		return -1;

	if (!concrete_type(dst->t)) {
		semantic_error(scope, n, "illegal type for lefthand side: %s",
		               type_str(dst->t));
		return -1;
	}

	struct ast *src = assign_r(n);
	if (analyze(state, scope, src))
		return -1;

	if (!concrete_type(src->t)) {
		semantic_error(scope, n, "illegal type for righthand side: %s",
		               type_str(src->t));
		return -1;
	}

	if (dst->t == src->t) {
		n->t = dst->t;
		return 0;
	}

	semantic_error(scope, n, "type mismatch: %s vs %s",
	               type_str(dst->t), type_str(src->t));
	return -1;
}

/* Check a call to the Today() built-in: it takes no arguments and has type
 * date. The call node is retagged as AST_BUILTIN_CALL for later phases. */
static int analyze_today(struct state *state, struct scope *scope,
                         struct ast *c)
{
	UNUSED(state);
	struct ast *args = func_call_args(c);
	if (ast_list_len(args) == 0) {
		c->k = AST_BUILTIN_CALL;
		c->t = TYPE_DATE;
		return 0;
	}

	semantic_error(scope, c, "expected 0 arguments, got %zd",
	               ast_list_len(args));
	return -1;
}

static int analyze_proc_call(struct state *state, struct scope *scope,
                             struct ast *c)
{
	struct ast *exists = file_scope_find_analyzed(state, scope,
	                                              proc_call_id(c));
	if (!exists || exists->k != AST_PROC_DEF) {
		semantic_error(scope, c, "no such proc");
		return -1;
	}

	if (exists->t == TYPE_AUTO) {
		semantic_error(scope, c, "proc call before deduction of auto");
		return -1;
	}

	struct ast *args = proc_call_args(c);
	if (analyze_list(state, scope, args))
		return -1;

	struct ast *formal = proc_formals(exists);
	if (ast_list_len(formal) != ast_list_len(args)) {
		semantic_error(scope, c, "expected %zd args, got %zd",
		               ast_list_len(formal), ast_list_len(args));
		return -1;
	}

	foreach_node(a, args) {
		if (a->t != formal->t) {
			semantic_error(scope, a, "expected %s, got %s",
			               type_str(formal->t), type_str(a->t));
			return -1;
		}

		formal = formal->n;
	}

	c->t = exists->t;
	return 0;
}

/* Check a func call @c. Calls to Today() go to the built-in handler.
 * Otherwise the callee must resolve to an AST_FUNC_DEF whose return type is
 * already known (not auto). Argument count and types must match its formals
 * exactly, and the call takes the func's return type. */
static int analyze_func_call(struct state *state, struct scope *scope,
                             struct ast *c)
{
	/* handle special Today() built-in */
	if (same_id(func_call_id(c), "Today"))
		return analyze_today(state, scope, c);

	struct ast *exists = file_scope_find_analyzed(state, scope,
	                                              func_call_id(c));
	if (!exists || exists->k != AST_FUNC_DEF) {
		semantic_error(scope, c, "no such func");
		return -1;
	}

	/* the callee is still being checked and no return has fixed its type */
	if (exists->t == TYPE_AUTO) {
		semantic_error(scope, c, "func call before deduction of auto");
		return -1;
	}

	struct ast *args = func_call_args(c);
	if (analyze_list(state, scope, args))
		return -1;

	struct ast *formal = func_formals(exists);
	if (ast_list_len(formal) != ast_list_len(args)) {
		semantic_error(scope, c, "expected %zd args, got %zd",
		               ast_list_len(formal), ast_list_len(args));
		return -1;
	}

	/* counts are equal, so formal walks in lockstep with the arguments */
	foreach_node(a, args) {
		if (a->t != formal->t) {
			/* report at the offending argument, as analyze_proc_call()
			 * does, rather than at the whole call */
			semantic_error(scope, a, "expected %s, got %s",
			               type_str(formal->t), type_str(a->t));
			return -1;
		}

		formal = formal->n;
	}

	c->t = exists->t;
	return 0;
}

/* Check an until loop @n. The body gets a fresh child scope, and the exit
 * condition is checked in that same scope after the body. The condition must
 * be bool or int ("truthy"); the loop itself has type void. */
static int analyze_until(struct state *state, struct scope *scope,
                         struct ast *n)
{
	struct scope *loop_scope = create_scope();
	scope_add_scope(scope, loop_scope);

	if (analyze_statement_list(state, loop_scope, until_body(n)))
		return -1;

	struct ast *cond = until_cond(n);
	if (analyze(state, loop_scope, cond))
		return -1;

	if (cond->t == TYPE_BOOL || cond->t == TYPE_INT) {
		n->t = TYPE_VOID;
		return 0;
	}

	semantic_error(scope, cond, "expected truthy type, got %s",
	               type_str(cond->t));
	return -1;
}

static int analyze_unless(struct state *state, struct scope *scope,
                          struct ast *n)
{
	struct scope *unless_scope = create_scope();
	struct scope *otherwise_scope = create_scope();
	scope_add_scope(scope, unless_scope);
	scope_add_scope(scope, otherwise_scope);

	struct ast *body = unless_body(n);
	if (analyze_statement_list(state, unless_scope, body))
		return -1;

	struct ast *cond = unless_cond(n);
	if (analyze(state, scope, cond))
		return -1;

	if (cond->t != TYPE_BOOL && cond->t != TYPE_INT) {
		semantic_error(scope, cond, "expected truthy type, got %s",
		               type_str(cond->t));
		return -1;
	}

	struct ast *otherwise = unless_otherwise(n);
	if (otherwise && analyze_statement_list(state, otherwise_scope, otherwise))
		return -1;

	n->t = TYPE_VOID;
	return 0;
}

/* Check an unless expression @n: body, condition and otherwise value are
 * analyzed in that order within @scope. The condition must be a bool (unlike
 * the statement form, int is not accepted here), and both value operands must
 * share one type, which becomes the expression's type. */
static int analyze_unless_expr(struct state *state, struct scope *scope,
                               struct ast *n)
{
	struct ast *then_val = unless_expr_body(n);
	if (analyze(state, scope, then_val))
		return -1;

	struct ast *cond = unless_expr_cond(n);
	if (analyze(state, scope, cond))
		return -1;

	if (cond->t != TYPE_BOOL) {
		semantic_error(scope, cond, "expected %s, got %s",
		               type_str(TYPE_BOOL), type_str(cond->t));
		return -1;
	}

	struct ast *else_val = unless_expr_otherwise(n);
	if (analyze(state, scope, else_val))
		return -1;

	if (then_val->t == else_val->t) {
		n->t = then_val->t;
		return 0;
	}

	semantic_error(scope, n, "type mismatch: %s vs %s",
	               type_str(then_val->t), type_str(else_val->t));
	return -1;
}

/* Central dispatcher: type-check node @n within @scope and record its type in
 * n->t. Returns 0 on success, -1 after an error has been reported.
 *
 * Each node is analyzed at most once. AST_FLAG_INIT marks a node whose
 * analysis has started, and AST_FLAG_CHECKED one whose analysis has finished.
 * The func/proc handlers set CHECKED early to permit recursion. Reaching a
 * node that is INIT but not yet CHECKED means a definition depends on itself,
 * e.g. a variable whose initializer refers to the variable itself.
 * CHECKED is set even when analysis fails; the checker bails out on the first
 * error, so such a node is not expected to be revisited. */
static int analyze(struct state *state, struct scope *scope, struct ast *n)
{
	int ret = 0;

	/* already done: a finished node must carry a type */
	if (ast_flags(n, AST_FLAG_CHECKED)) {
		assert(has_type(n));
		return 0;
	}

	/* started but not finished: we are inside our own analysis */
	if (ast_flags(n, AST_FLAG_INIT)) {
		semantic_error(scope, n, "semantic loop detected");
		return -1;
	}

	ast_set_flags(n, AST_FLAG_INIT);

	/* generally what we want to happen, special cases can handle themselves */
	if (!n->scope)
		n->scope = scope;

	switch (n->k) {
	case AST_FUNC_DEF: ret = analyze_func_def(state, scope, n); break;
	case AST_PROC_DEF: ret = analyze_proc_def(state, scope, n); break;
	case AST_VAR_DEF: ret = analyze_var_def(state, scope, n); break;
	case AST_FORMAL_DEF: ret = analyze_formal_def(state, scope, n); break;
	case AST_NEG: ret = analyze_neg(state, scope, n); break;
	case AST_POS: ret = analyze_pos(state, scope, n); break;
	case AST_CONST_DATE: ret = analyze_const_date(state, scope, n); break;
	case AST_CONST_INT: ret = analyze_const_int(state, scope, n); break;
	case AST_CONST_STRING: ret = analyze_const_string(state, scope, n);
		break;
	case AST_EQ: ret = analyze_eq(state, scope, n); break;
	case AST_LT: ret = analyze_lt(state, scope, n); break;
	case AST_ADD: ret = analyze_add(state, scope, n); break;
	case AST_SUB: ret = analyze_sub(state, scope, n); break;
	case AST_MUL: ret = analyze_mul(state, scope, n); break;
	case AST_DIV: ret = analyze_div(state, scope, n); break;
	case AST_PRINT: ret = analyze_print(state, scope, n); break;
	case AST_RETURN: ret = analyze_return(state, scope, n); break;
	case AST_ID: ret = analyze_id(state, scope, n); break;
	case AST_DOT: ret = analyze_dot(state, scope, n); break;
	case AST_ATTR: ret = analyze_attr(state, scope, n); break;
	case AST_ASSIGN: ret = analyze_assign(state, scope, n); break;
	case AST_PROC_CALL: ret = analyze_proc_call(state, scope, n); break;
	case AST_FUNC_CALL: ret = analyze_func_call(state, scope, n); break;
	case AST_UNTIL: ret = analyze_until(state, scope, n); break;
	case AST_UNLESS: ret = analyze_unless(state, scope, n); break;
	case AST_UNLESS_EXPR: ret = analyze_unless_expr(state, scope, n); break;
	default: break;
		 /* not all nodes are in this switch statement, as the internal
		  * ones are assumed to be correct */
	}

	if (ret == 0) {
		/* even though sometimes we might not need the type, this is a
		 * fairly effective sanity check */
		assert(has_type(n));
	}

	ast_set_flags(n, AST_FLAG_CHECKED);
	return ret;
}

/* this I guess would be classified as a two-phase semantic checker, where the
 * first phase checks that there aren't any multiply defined
 * procs/funcs/variables, and the second phase does (mainly) type checking. I didn't
 * explicitly call it type checking as it technically also checks date
 * attributes, date +- etc. which were fairly easy to implement in the same
 * phase. */
int check(struct scope *scope, struct ast *root)
{
	/* first add all procedures/functions to the top level scope so we can
	 * find them later */
	foreach_node(n, root) {
		if (analyze_visibility(scope, n))
			return -1;
	}

	/* actually analyze all nodes */
	struct state state = {0};
	if (analyze_statement_list(&state, scope, root))
		return -1;

#ifdef DEBUG
	foreach_node(n, root) {
		printf("// after checking:\n");
		ast_dump(0, n);
	}
#endif

	return 0;
}