/* SPDX-License-Identifier: copyleft-next-0.3.1 */
/* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */

#include <fwd/compiler.h>
#include <fwd/analyze.h>
#include <fwd/mod.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <assert.h>
#include <dlfcn.h>

/** Flags accumulated in struct state while analyzing. */
enum state_flags {
	/** Set by analyze_call() when one of the call arguments is a closure;
	 * analyze_proc() then marks the procedure with AST_REQ_FRAME. */
	STATE_PROC_REQ_FRAME = (1 << 0)
};

/** Analysis state threaded through the analyze_*() functions. */
struct state {
	/** Bitwise OR of enum state_flags values. */
	enum state_flags flags;
};

/* Forward declarations, the analyzers below call these before they are
 * defined. All of them return 0 on success and non-zero on failure. */
static int analyze(struct state *state, struct scope *scope, struct ast *node);
static int analyze_known_block(struct state *state, struct scope *scope, struct ast *node);
static int analyze_list(struct state *state, struct scope *scope,
                        struct ast *nodes);

static int analyze_type(struct state *state, struct scope *scope,
                        struct type *type);
static int analyze_type_list(struct state *state, struct scope *scope,
                             struct type *types);

static int analyze_proc(struct state *state, struct scope *scope,
                        struct ast *node)
{
	(void)state;
	(void)scope;

	struct scope *proc_scope = create_scope();
	if (!proc_scope) {
		internal_error("failed allocating proc scope");
		return -1;
	}

	proc_scope->flags |= SCOPE_PROC;
	scope_add_scope(scope, proc_scope);

	struct state proc_state = {};
	if (analyze_list(&proc_state, proc_scope, proc_params(node)))
		return -1;

	struct type *proc_type = tgen_func_ptr(NULL, node->loc);
	if (!proc_type) {
		internal_error("failed allocating proc type");
		return -1;
	}

	size_t group = 0;
	foreach_node(param, proc_params(node)) {
		if (param->ol == NULL && param->or == NULL)
			group++; /* no opt group */

		if (param->ol == NULL && param->or)
			group++; /* start of new group */

		/* otherwise same or ending group, don't increment */

		param->t->group = group;
		type_append(&proc_type->t0, clone_type(param->t));
	}

	node->t = proc_type;
	if (!proc_body(node))
		return 0;

	if (analyze_known_block(&proc_state, proc_scope, proc_body(node)))
		return -1;

	if (proc_state.flags & STATE_PROC_REQ_FRAME)
		ast_set_flags(node, AST_REQ_FRAME);

	return 0;
}

/**
 * Analyze a unary operation; the result has the type of the operand.
 *
 * @param state Current analysis state.
 * @param scope Scope the operation appears in.
 * @param node Unary operation to analyze.
 * @return \c 0 on success, \c -1 if the operand fails analysis.
 */
static int analyze_unop(struct state *state, struct scope *scope,
                        struct ast *node)
{
	/** @todo check expr is some primitive type */
	struct ast *operand = unop_expr(node);
	if (analyze(state, scope, operand) != 0)
		return -1;

	node->t = operand->t;
	return 0;
}

/**
 * Analyze a binary operation.
 *
 * Both operands are analyzed, left one first, and must have matching types.
 * The result takes the type of the left operand.
 *
 * @param state Current analysis state.
 * @param scope Scope the operation appears in.
 * @param node Binary operation to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_binop(struct state *state, struct scope *scope,
                         struct ast *node)
{
	struct ast *left = binop_left(node);
	struct ast *right = binop_right(node);

	/* short-circuits, the right side is skipped if the left side fails */
	if (analyze(state, scope, left) || analyze(state, scope, right))
		return -1;

	if (types_match(left->t, right->t)) {
		/** @todo check type is some primitive */
		node->t = left->t;
		return 0;
	}

	type_mismatch(scope, node, left->t, right->t);
	return -1;
}

/**
 * Analyze a comparison.
 *
 * Both operands are analyzed, left one first, and must have matching types.
 * The comparison itself is given a freshly generated ID type named "bool".
 *
 * @param state Current analysis state.
 * @param scope Scope the comparison appears in.
 * @param node Comparison to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_comparison(struct state *state, struct scope *scope,
                              struct ast *node)
{
	struct ast *left = comparison_left(node);
	struct ast *right = comparison_right(node);

	/* short-circuits, the right side is skipped if the left side fails */
	if (analyze(state, scope, left) || analyze(state, scope, right))
		return -1;

	if (!types_match(left->t, right->t)) {
		type_mismatch(scope, node, left->t, right->t);
		return -1;
	}

	/** @todo check type is some primitive */
	char *name = strdup("bool");
	if (name == NULL) {
		internal_error("failed allocating comparison bool str");
		return -1;
	}

	node->t = tgen_id(name, node->loc);
	if (node->t == NULL) {
		internal_error("failed allocating comparison bool type");
		/* the name was not taken over by any type, release it */
		free(name);
		return -1;
	}

	return 0;
}

/**
 * Analyze a block whose scope has already been created by the caller.
 *
 * The block is attached to \p scope, typed as void and each statement of its
 * body is analyzed in \p scope.
 *
 * @param state Current analysis state.
 * @param scope Scope the block body lives in.
 * @param node Block to analyze, must be a non-NULL AST_BLOCK.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_known_block(struct state *state, struct scope *scope, struct ast *node)
{
	assert(node && node->k == AST_BLOCK);

	node->scope = scope;
	node->t = tgen_void(node->loc);
	if (!node->t) {
		/* same handling as the other void type allocations in this
		 * file, don't rely on the assert in analyze() to catch this */
		internal_error("failed allocating void type for block");
		return -1;
	}

	if (analyze_list(state, scope, block_body(node)))
		return -1;

	return 0;
}

/**
 * Analyze a block that introduces its own scope.
 *
 * A new scope is created as a child of \p scope and the block body is
 * analyzed inside that new scope, so that symbols defined in the block are
 * not added to the enclosing scope.
 *
 * @param state Current analysis state.
 * @param scope Enclosing scope.
 * @param node Block to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_block(struct state *state, struct scope *scope,
                         struct ast *node)
{
	struct scope *block_scope = create_scope();
	if (!block_scope) {
		internal_error("failed to allocate block scope");
		return -1;
	}

	scope_add_scope(scope, block_scope);

	/* analyze in the freshly created scope, not in the parent */
	return analyze_known_block(state, block_scope, node);
}

/**
 * Analyze a variable definition.
 *
 * Resolves the declared type, uses it as the type of the definition and
 * registers the variable as a symbol in \p scope.
 *
 * @param state Current analysis state.
 * @param scope Scope the variable is defined in.
 * @param node Variable definition to analyze.
 * @return \c 0 on success, \c -1 if the type fails analysis, otherwise
 * whatever scope_add_symbol() returns.
 */
static int analyze_var(struct state *state, struct scope *scope,
                       struct ast *node)
{
	struct type *declared = var_type(node);
	if (analyze_type(state, scope, declared) != 0)
		return -1;

	node->t = declared;
	return scope_add_symbol(scope, node);
}

/**
 * Analyze a let statement.
 *
 * The variable is analyzed before the initializer expression, and their
 * types must match. The statement takes the type of the variable.
 *
 * @param state Current analysis state.
 * @param scope Scope the statement appears in.
 * @param node Let statement to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_let(struct state *state, struct scope *scope,
                       struct ast *node)
{
	struct ast *var = let_var(node);
	struct ast *expr = let_expr(node);

	/* short-circuits, the expression is skipped if the variable fails */
	if (analyze(state, scope, var) || analyze(state, scope, expr))
		return -1;

	struct type *var_t = var->t;
	struct type *expr_t = expr->t;
	if (types_match(var_t, expr_t)) {
		node->t = var_t;
		return 0;
	}

	type_mismatch(scope, node, var_t, expr_t);
	return -1;
}

/**
 * Analyze an init expression.
 *
 * The type arguments and the body are analyzed, after which the expression
 * is given a construct type built from a copy of the init ID and the type
 * arguments.
 *
 * @param state Current analysis state.
 * @param scope Scope the expression appears in.
 * @param node Init expression to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_init(struct state *state, struct scope *scope,
                        struct ast *node)
{
	if (analyze_type_list(state, scope, init_args(node)))
		return -1;

	if (analyze(state, scope, init_body(node)))
		return -1;

	/** @todo check that all parameters match, though that requires that the
	 * type is defined in fwd and not in C++, so this might have to wait a
	 * bit */
	char *id = strdup(init_id(node));
	if (!id) {
		internal_error("failed allocating construct type name");
		return -1;
	}

	node->t = tgen_construct(id, init_args(node), node->loc);
	if (!node->t) {
		internal_error("failed allocating type for construct");
		/* mirrors how the tgen_id() failures in this file release
		 * their name string */
		free(id);
		return -1;
	}

	return 0;
}

/**
 * Analyze a call.
 *
 * The called expression must be callable. The arguments are analyzed and
 * checked against the callable's parameter types, both in count and in type,
 * and each argument type inherits the group of the matching parameter type.
 * If any argument is a closure, STATE_PROC_REQ_FRAME is set in \p state.
 * The call itself is typed as void.
 *
 * @param state Current analysis state, may get STATE_PROC_REQ_FRAME set.
 * @param scope Scope the call appears in.
 * @param node Call to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_call(struct state *state, struct scope *scope,
                        struct ast *node)
{
	struct ast *expr = call_expr(node);
	if (analyze(state, scope, expr))
		return -1;

	if (!is_callable(expr->t)) {
		semantic_error(scope, node, "expected callable expression");
		return -1;
	}

	struct type *types = callable_types(expr->t);
	struct ast *args = call_args(node);
	if (analyze_list(state, scope, args))
		return -1;

	size_t expected = type_list_len(types);
	size_t got = ast_list_len(args);
	if (expected != got) {
		/* both counts are size_t, so %zu is the matching conversion */
		semantic_error(scope, node, "expected %zu params, got %zu",
		               expected, got);
		return -1;
	}

	/* the lists are the same length, so type stays valid for every arg */
	struct type *type = types;
	foreach_node(arg, args) {
		if (arg->k == AST_CLOSURE)
			state->flags |= STATE_PROC_REQ_FRAME;

		if (!types_match(type, arg->t)) {
			type_mismatch(scope, node, type, arg->t);
			return -1;
		}

		/* clone group info */
		arg->t->group = type->group;
		type = type->n;
	}

	node->t = tgen_void(node->loc);
	if (!node->t) {
		internal_error("failed allocating void type for call");
		return -1;
	}

	return 0;
}

/**
 * Analyze a reference expression.
 *
 * The base expression must be an lvalue; the result is a reference type
 * wrapping the type of the base.
 *
 * @param state Current analysis state.
 * @param scope Scope the expression appears in.
 * @param node Reference expression to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_ref(struct state *state, struct scope *scope,
                       struct ast *node)
{
	struct ast *base = ref_base(node);
	if (analyze(state, scope, base) != 0)
		return -1;

	if (!is_lvalue(base)) {
		semantic_error(node->scope, node, "trying to reference rvalue");
		return -1;
	}

	struct type *wrapped = tgen_ref(base->t, node->loc);
	if (wrapped == NULL) {
		internal_error("failed allocating ref type");
		return -1;
	}

	node->t = wrapped;
	return 0;
}

/**
 * Analyze a dereference.
 *
 * Only references may be dereferenced. Raw pointers are rejected with a hint
 * to convert them first, anything else is rejected outright. The result has
 * the base type of the reference.
 *
 * @param state Current analysis state.
 * @param scope Scope the expression appears in.
 * @param node Dereference to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_deref(struct state *state, struct scope *scope,
                         struct ast *node)
{
	struct ast *base = deref_base(node);
	if (analyze(state, scope, base) != 0)
		return -1;

	switch (base->t->k) {
	case TYPE_REF:
		node->t = tptr_base(base->t);
		return 0;

	case TYPE_PTR:
		semantic_error(node->scope, node,
		               "deref of raw ptr not allowed");
		semantic_info(node->scope, node,
		              "use fwd_null() to convert to ref");
		return -1;

	default:
		semantic_error(node->scope, node,
		               "deref of something not a reference");
		return -1;
	}
}

/**
 * Analyze an identifier by looking up the symbol it names.
 *
 * The identifier takes the type of the definition that was found. If the
 * definition has no type yet and has not been analyzed, it is analyzed on
 * demand first; only procedure definitions are expected to be in that state.
 *
 * @param state Current analysis state.
 * @param scope Scope the lookup starts from.
 * @param node Identifier to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_id(struct state *state, struct scope *scope,
                      struct ast *node)
{
	struct ast *def = file_scope_find_symbol(scope, id_str(node));
	if (def == NULL) {
		semantic_error(scope, node, "no symbol named \"%s\"",
		               id_str(node));
		return -1;
	}

	/* common case, the definition already has a type or is done */
	if (def->t || ast_flags(def, AST_FLAG_ANALYZED)) {
		node->t = def->t;
		return 0;
	}

	/* kind of hacky, functions are given their type before they've been
	 * analyzed fully, this enables recursion */
	assert(def->k == AST_PROC_DEF);

	/* a proc def will use its own state and scope, but pass these on in
	 * case the analysis wants to print errors or something */
	if (analyze(state, scope, def) != 0)
		return -1;

	node->t = def->t;
	return 0;
}

static int analyze_closure(struct state *state, struct scope *scope,
                           struct ast *node)
{
	struct scope *closure_scope = create_scope();
	if (!closure_scope) {
		internal_error("failed allocating closure scope");
		return -1;
	}

	node->scope = closure_scope;
	closure_scope->flags |= SCOPE_PROC;
	scope_add_scope(scope, closure_scope);

	if (analyze_list(state, closure_scope, closure_bindings(node)))
		return -1;

	if (analyze_known_block(state, closure_scope, closure_body(node)))
		return -1;

	struct type *callable = NULL;
	if (ast_flags(node, AST_FLAG_NOMOVES))
		callable = tgen_pure_closure(NULL, node->loc);
	else
		callable = tgen_closure(NULL, node->loc);

	if (!callable) {
		internal_error("failed allocating closure type");
		return -1;
	}

	foreach_node(binding, closure_bindings(node)) {
		type_append(&callable->t0, clone_type(binding->t));
	}

	node->t = callable;
	return 0;
}

/**
 * Analyze an integer constant, which is typed as TYPE_I64.
 *
 * @param state Unused.
 * @param scope Unused.
 * @param node Integer constant to analyze.
 * @return \c 0 on success, \c -1 if the type could not be allocated.
 */
static int analyze_int(struct state *state, struct scope *scope,
                       struct ast *node)
{
	(void)scope;
	(void)state;

	/* start with largest possible and work down */
	node->t = tgen_type(TYPE_I64, NULL, NULL, node->loc);
	if (node->t != NULL)
		return 0;

	internal_error("failed allocating constant int type");
	return -1;
}

/**
 * Analyze a string constant, which is typed as a pointer to an ID type named
 * "char".
 *
 * @param state Unused.
 * @param scope Unused.
 * @param node String constant to analyze.
 * @return \c 0 on success, \c -1 if an allocation fails.
 */
static int analyze_str(struct state *state, struct scope *scope,
                       struct ast *node)
{
	(void)scope;
	(void)state;

	/** @todo do this properly, very hacky, bad bad bad */
	char *name = strdup("char");
	if (name == NULL) {
		internal_error("failed allocating constant char type string");
		return -1;
	}

	struct type *elem = tgen_id(name, node->loc);
	if (elem == NULL) {
		internal_error("failed allocating constant char type");
		/* the name was not taken over by any type, release it */
		free(name);
		return -1;
	}

	struct type *ptr = tgen_ptr(elem, node->loc);
	if (ptr == NULL) {
		internal_error("failed allocating constant str type");
		return -1;
	}

	node->t = ptr;
	return 0;
}

/**
 * Analyze an if statement.
 *
 * The condition, the body and the else branch are analyzed in that order,
 * stopping at the first failure. The statement itself is typed as void.
 *
 * @param state Current analysis state.
 * @param scope Scope the statement appears in.
 * @param node If statement to analyze.
 * @return \c 0 on success, \c -1 on failure.
 */
static int analyze_if(struct state *state, struct scope *scope,
                      struct ast *node)
{
	/* short-circuits, so later parts are skipped after a failure */
	if (analyze(state, scope, if_cond(node))
	    || analyze(state, scope, if_body(node))
	    || analyze(state, scope, if_else(node)))
		return -1;

	node->t = tgen_void(node->loc);
	if (node->t != NULL)
		return 0;

	internal_error("failed allocating 'if' void type");
	return -1;
}

/**
 * Analyze a single AST node, dispatching on its kind.
 *
 * A NULL node and a node already flagged AST_FLAG_ANALYZED are accepted as
 * is. A node flagged AST_FLAG_PREANALYZIS but not AST_FLAG_ANALYZED is
 * rejected as a semantic loop. Otherwise the node is flagged
 * AST_FLAG_PREANALYZIS, given \p scope if it has no scope yet, handed to the
 * analyzer matching its kind and, on success, flagged AST_FLAG_ANALYZED.
 *
 * @param state Current analysis state.
 * @param scope Scope the node appears in.
 * @param node Node to analyze, may be NULL.
 * @return \c 0 on success, non-zero on failure.
 */
static int analyze(struct state *state, struct scope *scope, struct ast *node)
{
	if (!node)
		return 0;

	if (ast_flags(node, AST_FLAG_ANALYZED)) {
		assert(node->t);
		return 0;
	}

	if (ast_flags(node, AST_FLAG_PREANALYZIS)) {
		semantic_error(scope, node, "semantic loop detected");
		return -1;
	}

	ast_set_flags(node, AST_FLAG_PREANALYZIS);

	if (!node->scope)
		node->scope = scope;

	int ret = 0;
	if (is_binop(node)) {
		ret = analyze_binop(state, scope, node);
		goto out;
	}

	if (is_comparison(node)) {
		ret = analyze_comparison(state, scope, node);
		goto out;
	}

	if (is_unop(node)) {
		ret = analyze_unop(state, scope, node);
		goto out;
	}

	switch (node->k) {
	case AST_CLOSURE: ret = analyze_closure(state, scope, node); break;
	case AST_PROC_DEF: ret = analyze_proc(state, scope, node); break;
	case AST_VAR_DEF: ret = analyze_var(state, scope, node); break;
	case AST_BLOCK: ret = analyze_block(state, scope, node); break;
	case AST_DEREF: ret = analyze_deref(state, scope, node); break;
	case AST_INIT: ret = analyze_init(state, scope, node); break;
	case AST_CALL: ret = analyze_call(state, scope, node); break;
	case AST_REF: ret = analyze_ref(state, scope, node); break;
	case AST_LET: ret = analyze_let(state, scope, node); break;
	case AST_ID: ret = analyze_id(state, scope, node); break;
	case AST_IF: ret = analyze_if(state, scope, node); break;

	case AST_CONST_INT: ret = analyze_int(state, scope, node); break;
	case AST_CONST_STR: ret = analyze_str(state, scope, node); break;

	/* nothing to check for these, they are simply typed as void */
	case AST_EMPTY:
	case AST_IMPORT:
		node->t = tgen_void(node->loc);
		if (!node->t) {
			/* report the failure instead of relying on the assert
			 * below, which is compiled out with NDEBUG */
			internal_error("failed allocating void type");
			ret = -1;
		}
		break;

	default:
		internal_error("missing ast analysis for %s", ast_str(node->k));
		return -1;
	}

out:
	if (ret)
		return ret;

	assert(node->t);
	assert(node->scope);
	ast_set_flags(node, AST_FLAG_ANALYZED);
	return 0;
}

static int analyze_list(struct state *state, struct scope *scope,
                        struct ast *nodes)
{
	foreach_node(node, nodes) {
		if (analyze(state, scope, node))
			return -1;
	}

	return 0;
}

/**
 * Resolve a type in place.
 *
 * The nested type in \c t0, if any, is resolved first. A TYPE_ID type whose
 * name is one of the built-in integer names then has its kind replaced by
 * the matching primitive kind; any other TYPE_ID name is reported as an
 * internal error and aborts the program. Types of other kinds are left
 * untouched.
 *
 * @param state Current analysis state.
 * @param scope Scope the type appears in.
 * @param type Type to resolve.
 * @return \c 0 on success, \c -1 if the nested type fails.
 */
static int analyze_type(struct state *state, struct scope *scope,
                        struct type *type)
{
	/* names of the built-in integer types and the kinds they map to */
	static const struct {
		const char *name;
		int kind;
	} primitives[] = {
		{"i8", TYPE_I8},
		{"u8", TYPE_U8},
		{"i16", TYPE_I16},
		{"u16", TYPE_U16},
		{"i32", TYPE_I32},
		{"u32", TYPE_U32},
		{"i64", TYPE_I64},
		{"u64", TYPE_U64},
	};
	const size_t count = sizeof(primitives) / sizeof(primitives[0]);

	if (type->t0 && analyze_type(state, scope, type->t0))
		return -1;

	/* temporary */
	if (type->k != TYPE_ID)
		return 0;

	const char *name = type->id;

	size_t i = 0;
	while (i < count && strcmp(name, primitives[i].name) != 0)
		i++;

	if (i == count) {
		internal_error("unhandled type id: %s", name);
		abort();
	}

	type->k = primitives[i].kind;
	return 0;
}

static int analyze_type_list(struct state *state, struct scope *scope,
                             struct type *types)
{
	foreach_type(type, types) {
		if (analyze_type(state, scope, type))
			return -1;
	}

	return 0;
}

/**
 * Copy the public definitions of one scope into another.
 *
 * Symbols, types and traits of \p from that carry AST_FLAG_PUBLIC are added
 * to the matching table of \p to. Unless \p reexport is set, definitions
 * whose own scope is not \p from are skipped.
 *
 * @param reexport Whether definitions not owned by \p from are copied too.
 * @param to Scope receiving the definitions.
 * @param from Scope the definitions are read from.
 * @return \c 0 on success, \c -1 if adding a definition fails.
 */
static int copy_scope(bool reexport, struct scope *to, struct scope *from)
{
	foreach(visible, symbol, &from->symbols) {
		struct ast *def = symbol->data;
		if (!reexport && def->scope != from)
			continue;

		if (!ast_flags(def, AST_FLAG_PUBLIC))
			continue;

		if (scope_add_symbol(to, def))
			return -1;
	}

	foreach(visible, symbol, &from->types) {
		struct ast *def = symbol->data;
		if (!reexport && def->scope != from)
			continue;

		if (!ast_flags(def, AST_FLAG_PUBLIC))
			continue;

		/* types go into the type table, same as in analyze_root() */
		if (scope_add_type(to, def))
			return -1;
	}

	foreach(visible, symbol, &from->traits) {
		struct ast *def = symbol->data;
		if (!reexport && def->scope != from)
			continue;

		if (!ast_flags(def, AST_FLAG_PUBLIC))
			continue;

		/* traits go into the trait table, same as in analyze_root() */
		if (scope_add_trait(to, def))
			return -1;
	}

	return 0;
}

/**
 * Try to import a source file: compile it and copy its public definitions
 * into \p scope. Definitions not owned by the imported file are copied as
 * well when the import node itself carries AST_FLAG_PUBLIC.
 *
 * Allowed to be noisy, a failure is annotated with where the import happened.
 *
 * @param scope Scope receiving the imported definitions.
 * @param node Import node.
 * @return \c 0 on success, \c -1 on failure.
 */
static int try_import_file(struct scope *scope, struct ast *node)
{
	const char *path = import_file(node);
	struct scope *imported = compile_file(path);

	/* short-circuits, nothing is copied if compilation failed */
	if (imported
	    && !copy_scope(ast_flags(node, AST_FLAG_PUBLIC), scope, imported))
		return 0;

	semantic_info(scope, node, "imported here");
	return -1;
}

/* should be quiet upon failure */
static int try_import_mod(struct scope *scope, struct ast *node)
{
	/** @todo cleanup */
	void *mod = dlopen(import_file(node), RTLD_LAZY);
	if (!mod)
		return -1;

	fwd_open_t fwdopen = dlsym(mod, "fwdopen");
	if (!fwdopen) {
		dlclose(mod);
		return -1;
	}

	mod_vec_append(&scope->mods, mod);
	return fwdopen((void *)scope);
}

/**
 * Generate the compiler type matching a module interface type.
 *
 * FWD_VOID maps to void, FWD_I64 to TYPE_I64 and FWD_PTR to a pointer to
 * void. Any other value aborts the program.
 *
 * @param type Module interface type.
 * @return Newly generated type with a null location.
 */
static struct type *fwd_type_kind(fwd_type_t type)
{
	if (type == FWD_VOID)
		return tgen_void(NULL_LOC());

	if (type == FWD_I64)
		return tgen_type(TYPE_I64, NULL, NULL, NULL_LOC());

	if (type == FWD_PTR) {
		struct type *pointee = tgen_void(NULL_LOC());
		return tgen_ptr(pointee, NULL_LOC());
	}

	/* no mapping for this interface type */
	abort();
	return NULL;
}

/**
 * Register an external procedure in the scope behind \p state.
 *
 * A body-less procedure definition called \p name is generated with one
 * parameter, named "arg", per variadic type, in the order the types were
 * given, and added to the scope as a symbol.
 *
 * @param state Opaque handle, actually the struct scope to register into.
 * @param name Name of the procedure, copied.
 * @param func External implementation, not used by this function.
 * @param rtype Return type of the procedure.
 * @param ... Parameter types as fwd_type_t values, terminated by FWD_END.
 * @return \c 0 on success, \c -1 on failure.
 */
int fwd_register(struct fwd_state *state, const char *name, fwd_extern_t func, fwd_type_t rtype, ...)
{
	(void)func;

	struct scope *scope = (void *)state;
	struct ast *vars = NULL;

	va_list args;
	va_start(args, rtype);
	while (1) {
		fwd_type_t type = va_arg(args, enum fwd_type);
		if (type == FWD_END)
			break;

		struct ast *var = gen_var(strdup("arg"), fwd_type_kind(type), NULL_LOC());
		if (!var) {
			internal_error("failed allocating extern proc parameter");
			va_end(args);
			return -1;
		}

		/* prepend, so no parameter is dropped; the reversal below
		 * restores the order the types were passed in */
		var->n = vars;
		vars = var;
	}
	va_end(args);

	vars = reverse_ast_list(vars);

	struct ast *def = gen_proc(strdup(name), vars,
			fwd_type_kind(rtype), NULL, NULL_LOC());
	if (!def) {
		internal_error("failed allocating extern proc definition");
		return -1;
	}

	if (scope_add_symbol(scope, def))
		return -1;

	return 0;
}

int analyze_root(struct scope *scope, struct ast *root)
{
	foreach_node(node, root) {
		switch (node->k) {
		case AST_PROC_DEF:
		      if (scope_add_symbol(scope, node))
			      return -1;
		      break;

		case AST_STRUCT_DEF:
		      if (scope_add_type(scope, node))
			      return -1;
		      break;

		case AST_STRUCT_CONT:
		      if (scope_add_chain(scope, node))
			      return -1;
		      break;

		case AST_TRAIT_DEF:
		      if (scope_add_trait(scope, node))
			      return -1;
		      break;

		case AST_IMPORT: {
			if (!try_import_mod(scope, node))
				break;

			if (!try_import_file(scope, node))
				break;

			return -1;
		}

		default:
		      abort();
		}
	}

	foreach_node(node, root) {
		struct state state = {};
		if (analyze(&state, scope, node))
			return -1;
	}

	return 0;
}