/* SPDX-License-Identifier: copyleft-next-0.3.1 */
/* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */

/**
 * @file ast.c
 *
 * Abstract syntax tree handling implementations.
 */

#include <stdio.h>
#include <string.h>
#include <stdarg.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>

#include <fwd/ast.h>
#include <fwd/vec.h>
#include <fwd/scope.h>

/* Registry of every `struct ast *` handed out by create_empty_ast(). It is
 * created lazily on the first allocation (see the vec_uninit() check there)
 * and walked by destroy_ast_nodes() to free all nodes at once. */
static struct vec nodes = {0};
/* Same as `nodes`, but for the `struct type *` allocated by
 * create_empty_type() and released by destroy_types(). */
static struct vec types = {0};

/**
 * Free a single AST node together with the string it owns.
 *
 * Only the node itself and its `s` payload are released; children and
 * siblings are separate registry entries and are freed on their own.
 *
 * @param n Node to free; NULL is ignored.
 */
static void destroy_ast_node(struct ast *n)
{
	if (!n)
		return;

	/* free(NULL) is a no-op, so the owned string needs no guard */
	free(n->s);
	free(n);
}

/**
 * Free a single type node together with the identifier it owns.
 *
 * Subtypes (`t0`) and siblings (`n`) are separate registry entries and are
 * freed on their own.
 *
 * @param n Type to free; NULL is ignored.
 */
static void destroy_type(struct type *n)
{
	if (!n)
		return;

	/* free(NULL) is a no-op, so the owned identifier needs no guard */
	free(n->id);
	free(n);
}

/* Free every AST node recorded by create_empty_ast(), then the registry. */
static void destroy_ast_nodes(void)
{
	foreach_vec(ni, nodes) {
		struct ast *n = vect_at(struct ast *, nodes, ni);
		destroy_ast_node(n);
	}

	vec_destroy(&nodes);
}

/* Free every type recorded by create_empty_type(), then the registry. */
static void destroy_types(void)
{
	foreach_vec(ti, types) {
		struct type *t = vect_at(struct type *, types, ti);
		destroy_type(t);
	}

	vec_destroy(&types);
}

/**
 * Release every AST node and type allocated through this module.
 *
 * Any pointer previously returned by gen_ast(), tgen_type() or the clone
 * functions is dangling after this call.
 */
void destroy_allocs(void)
{
	destroy_ast_nodes();
	destroy_types();
}

/**
 * Allocate a zeroed AST node of kind AST_EMPTY and record it in `nodes`,
 * so that destroy_allocs() can release it later.
 *
 * @return The new node, or NULL if the allocation failed (nothing is
 *         recorded in that case).
 */
static struct ast *create_empty_ast(void)
{
	if (vec_uninit(nodes))
		nodes = vec_create(sizeof(struct ast *));

	struct ast *n = calloc(1, sizeof(*n));
	if (!n)
		return NULL;

	/* calloc already zeroes the node, but AST_EMPTY may not be the zero
	 * enumerator, so set it explicitly */
	n->k = AST_EMPTY;
	vect_append(struct ast *, nodes, &n);
	return n;
}

/**
 * Allocate a zeroed type node and record it in `types`, so that
 * destroy_allocs() can release it later.
 *
 * @return The new type, or NULL if the allocation failed (nothing is
 *         recorded in that case).
 */
static struct type *create_empty_type(void)
{
	if (vec_uninit(types))
		types = vec_create(sizeof(struct type *));

	struct type *n = calloc(1, sizeof(*n));
	if (!n)
		return NULL;

	/* the element type must match what `types` stores: struct type * */
	vect_append(struct type *, types, &n);
	return n;
}

/**
 * Build an AST node of the given kind from its individual fields.
 *
 * The node is registered for bulk release by destroy_allocs(). It takes
 * ownership of @p s, which destroy_allocs() frees together with the node.
 *
 * @param kind Node kind.
 * @param a0 First child list, may be NULL.
 * @param a1 Second child list, may be NULL.
 * @param a2 Third child list, may be NULL.
 * @param a3 Fourth child list, may be NULL.
 * @param t2 Type attached to the node, may be NULL.
 * @param s  Owned string payload, may be NULL.
 * @param v  Integer payload.
 * @param d  Floating-point payload.
 * @param loc Source location of the node.
 * @return The new node, or NULL if it could not be allocated. In that case
 *         ownership of @p s stays with the caller.
 */
struct ast *gen_ast(enum ast_kind kind,
                    struct ast *a0,
                    struct ast *a1,
                    struct ast *a2,
                    struct ast *a3,
                    struct type *t2,
                    char *s,
                    long long v,
                    double d,
                    struct src_loc loc)
{
	struct ast *n = create_empty_ast();
	/* allocation can fail; report it with NULL, as clone_type() does */
	if (!n)
		return NULL;

	n->k = kind;
	n->a0 = a0;
	n->a1 = a1;
	n->a2 = a2;
	n->a3 = a3;
	n->t2 = t2;
	n->s = s;
	n->v = v;
	n->d = d;
	n->loc = loc;
	return n;
}

/**
 * Build a type node of the given kind.
 *
 * The type is registered for bulk release by destroy_allocs(). It takes
 * ownership of @p id, which destroy_allocs() frees together with the node.
 *
 * @param kind Type kind.
 * @param t0   Subtype list, may be NULL.
 * @param id   Owned identifier, may be NULL.
 * @param loc  Source location of the type.
 * @return The new type, or NULL if it could not be allocated. In that case
 *         ownership of @p id stays with the caller.
 */
struct type *tgen_type(enum type_kind kind,
                       struct type *t0,
                       char *id,
                       struct src_loc loc)
{
	struct type *n = create_empty_type();
	/* create_empty_type() can return NULL on allocation failure; callers
	 * such as clone_type() already check for it, so do the same here */
	if (!n)
		return NULL;

	n->k = kind;
	n->t0 = t0;
	n->id = id;
	n->loc = loc;
	return n;
}

/**
 * Append @p elem to the end of the singly linked node list at @p list.
 *
 * If the list is empty, @p elem becomes its head.
 *
 * @param list Address of the list head; must not be NULL.
 * @param elem Node to append; must not be NULL.
 */
void ast_append(struct ast **list, struct ast *elem)
{
	assert(list);
	assert(elem);

	/* follow the `n` links until reaching the NULL slot, then fill it */
	struct ast **link = list;
	while (*link)
		link = &(*link)->n;

	*link = elem;
}

/**
 * Append @p elem to the end of the singly linked type list at @p list.
 *
 * If the list is empty, @p elem becomes its head.
 *
 * @param list Address of the list head; must not be NULL.
 * @param elem Type to append; must not be NULL.
 */
void type_append(struct type **list, struct type *elem)
{
	assert(list);
	assert(elem);

	/* follow the `n` links until reaching the NULL slot, then fill it */
	struct type **link = list;
	while (*link)
		link = &(*link)->n;

	*link = elem;
}

/**
 * Put @p elem in front of @p list.
 *
 * @param list Current list head, may be NULL.
 * @param elem Node that becomes the new head. Its previous `n` link is
 *             overwritten.
 * @return The new list head, i.e. @p elem.
 */
struct ast *ast_prepend(struct ast *list, struct ast *elem)
{
	struct ast *head = elem;
	head->n = list;
	return head;
}

/*
 * printf-style debug output to stdout. Each call starts with "//" followed
 * by two spaces per level of @p depth, then the formatted text.
 */
static void dump(int depth, const char *fmt, ...)
{
	fputs("//", stdout);
	for (int level = depth; level > 0; --level)
		fputs("  ", stdout);

	va_list ap;
	va_start(ap, fmt);
	vprintf(fmt, ap);
	va_end(ap);
}

void ast_dump(int depth, struct ast *n)
{
	if (!n) {
		dump(depth, "{NULL}\n");
		return;
	}

#define DUMP(x) case x: dump(depth, #x); break;
	switch (n->k) {
	DUMP(AST_CLOSURE);
	DUMP(AST_IF);
	DUMP(AST_LET);
	DUMP(AST_INIT);
	DUMP(AST_CALL);
	DUMP(AST_PROC_DEF);
	DUMP(AST_VAR_DEF);
	DUMP(AST_DOT);
	DUMP(AST_BLOCK);
	DUMP(AST_ID);
	DUMP(AST_EMPTY);
	DUMP(AST_ADD);
	DUMP(AST_SUB);
	DUMP(AST_MUL);
	DUMP(AST_DIV);
	DUMP(AST_REM);
	DUMP(AST_LAND);
	DUMP(AST_LOR);
	DUMP(AST_LSHIFT);
	DUMP(AST_RSHIFT);
	DUMP(AST_LT);
	DUMP(AST_GT);
	DUMP(AST_LE);
	DUMP(AST_GE);
	DUMP(AST_NE);
	DUMP(AST_EQ);
	DUMP(AST_NEG);
	DUMP(AST_LNOT);
	DUMP(AST_NOT);
	DUMP(AST_CONST_INT);
	DUMP(AST_CONST_CHAR);
	DUMP(AST_CONST_BOOL);
	DUMP(AST_CONST_FLOAT);
	DUMP(AST_CONST_STR);
	}
#undef DUMP

	depth++;

	if (n->scope)
		printf(" {%llu}", (unsigned long long)n->scope->number);

	printf("\n");

	if (n->s)
		dump(depth, "%s\n", n->s);

	if (is_const(n))
		dump(depth, "%lli\n", n->v);

	if (n->a0)
		ast_dump_list(depth, n->a0);

	if (n->a1)
		ast_dump_list(depth, n->a1);

	if (n->a2)
		ast_dump_list(depth, n->a2);

	if (n->a3)
		ast_dump_list(depth, n->a3);
}

void ast_dump_list(int depth, struct ast *root)
{
	if (!root) {
		dump(depth, "{NULL}\n");
		return;
	}

	foreach_node(n, root) {
		ast_dump(depth, n);
	}
}

/**
 * Deep-copy a single AST node (its `n` sibling link is not followed).
 *
 * The string payload, the attached type and all child lists are duplicated.
 * The scope pointer is shared with the original.
 *
 * @param n Node to copy, may be NULL.
 * @return The copy, or NULL if @p n is NULL or the node could not be
 *         allocated.
 */
struct ast *clone_ast(struct ast *n)
{
	if (!n)
		return NULL;

	assert(n->k);
	struct ast *new = create_empty_ast();
	if (!new)
		return NULL;

	new->scope = n->scope;
	new->loc = n->loc;
	new->k = n->k;
	new->v = n->v;
	new->d = n->d;

	if (n->s)
		new->s = strdup(n->s);

	/* the type annotation used to be dropped; copy it like the other
	 * owned fields so the clone keeps its type */
	if (n->t2)
		new->t2 = clone_type_list(n->t2);

	if (n->a0)
		new->a0 = clone_ast_list(n->a0);

	if (n->a1)
		new->a1 = clone_ast_list(n->a1);

	if (n->a2)
		new->a2 = clone_ast_list(n->a2);

	if (n->a3)
		new->a3 = clone_ast_list(n->a3);

	return new;
}

/**
 * Deep-copy a whole sibling list of AST nodes, preserving its order.
 *
 * @param root Head of the list, may be NULL.
 * @return Head of the copied list. Returns NULL for an empty list or if
 *         any element could not be cloned (already-cloned nodes stay in
 *         the allocation registry and are released by destroy_allocs()).
 */
struct ast *clone_ast_list(struct ast *root)
{
	struct ast *head = NULL;
	struct ast **tail = &head;

	for (struct ast *cur = root; cur; cur = cur->n) {
		struct ast *copy = clone_ast(cur);
		/* report failure rather than overwriting the head and silently
		 * losing the part cloned so far */
		if (!copy)
			return NULL;

		*tail = copy;
		tail = &copy->n;
	}

	return head;
}

/**
 * Deep-copy a single type node (its `n` sibling link is not followed).
 *
 * The kind, source location, identifier and subtype list are copied.
 *
 * @param type Type to copy, may be NULL.
 * @return The copy, or NULL if @p type is NULL or any allocation failed.
 */
struct type *clone_type(struct type *type)
{
	if (!type)
		return NULL;

	struct type *new = create_empty_type();
	if (!new)
		return NULL;

	new->k = type->k;
	/* keep the source location so diagnostics on the clone still point
	 * at the original declaration */
	new->loc = type->loc;

	if (type->id && !(new->id = strdup(type->id)))
		return NULL;

	if (type->t0 && !(new->t0 = clone_type_list(type->t0)))
		return NULL;

	return new;
}

/**
 * Deep-copy a whole sibling list of types, preserving its order.
 *
 * @param root Head of the list, may be NULL.
 * @return Head of the copied list. Returns NULL for an empty list or if
 *         any element could not be cloned, so that callers such as
 *         clone_type() can detect the failure.
 */
struct type *clone_type_list(struct type *root)
{
	struct type *head = NULL;
	struct type **tail = &head;

	for (struct type *cur = root; cur; cur = cur->n) {
		struct type *copy = clone_type(cur);
		/* clone_type() reports failure with NULL; propagate it instead of
		 * overwriting the head and dropping the nodes cloned so far */
		if (!copy)
			return NULL;

		*tail = copy;
		tail = &copy->n;
	}

	return head;
}

/**
 * Walk the subtree rooted at @p n in depth-first order.
 *
 * @p before is called on a node before its child lists a0..a3 are visited,
 * and @p after once they are done. Either callback may be NULL. The first
 * non-zero value returned by a callback aborts the walk and is returned.
 *
 * @return 0 if the whole subtree was visited, otherwise the aborting value.
 */
int ast_visit(ast_callback_t before, ast_callback_t after, struct ast *n,
              void *d)
{
	if (n == NULL)
		return 0;

	int status = before ? before(n, d) : 0;
	if (status)
		return status;

	/* Keep the addresses of the child slots rather than their values, so
	 * each child pointer is read only when its turn comes. */
	struct ast **slots[] = { &n->a0, &n->a1, &n->a2, &n->a3 };
	for (size_t i = 0; i < sizeof(slots) / sizeof(slots[0]); ++i) {
		struct ast *child = *slots[i];
		if (!child)
			continue;

		status = ast_visit_list(before, after, child, d);
		if (status)
			return status;
	}

	return after ? after(n, d) : 0;
}

/**
 * Run ast_visit() on every node of a sibling list, stopping at the first
 * non-zero result.
 *
 * @return 0 if every node was visited, otherwise the aborting value.
 */
int ast_visit_list(ast_callback_t before, ast_callback_t after, struct ast *l,
                   void *d)
{
	foreach_node(node, l) {
		int status = ast_visit(before, after, node, d);
		if (status)
			return status;
	}

	return 0;
}

/* Number of nodes reachable from @p node through `n` links (0 for NULL). */
size_t ast_list_len(struct ast *node)
{
	size_t len = 0;
	for (; node; node = node->n)
		++len;

	return len;
}

/* Number of types reachable from @p node through `n` links (0 for NULL). */
size_t type_list_len(struct type *node)
{
	size_t len = 0;
	for (; node; node = node->n)
		++len;

	return len;
}

/* Last node of a sibling list, or NULL if the list is empty. */
struct ast *ast_last(struct ast *list)
{
	struct ast *tail = list;
	while (tail && tail->n)
		tail = tail->n;

	return tail;
}

/**
 * Last statement of @p block, looking through nested blocks.
 *
 * Whenever the last statement is itself an AST_BLOCK, the search continues
 * with that block's body.
 *
 * @return The innermost last statement, or NULL if a list along the way
 *         is empty.
 */
struct ast *ast_block_last(struct ast *block)
{
	struct ast *last = ast_last(block);
	while (last && last->k == AST_BLOCK)
		last = ast_last(block_body(last));

	return last;
}

/* Non-zero when the two identifiers are equal strings. Both must be
 * non-NULL. */
int same_id(char *id1, char *id2)
{
	return !strcmp(id1, id2);
}

/* Compare two optional child lists: both absent is equivalent, exactly one
 * absent is not, otherwise defer to equiv_node_lists(). */
static int equiv_child_lists(struct ast *l1, struct ast *l2)
{
	if (!l1 || !l2)
		return l1 == l2;

	return equiv_node_lists(l1, l2);
}

/**
 * Structural equivalence of two AST nodes.
 *
 * Nodes are equivalent when they have the same kind, the same string
 * payload (both absent or equal), and equivalent child lists a0..a3.
 * The check is symmetric in its arguments.
 *
 * NOTE(review): the numeric payloads `v` and `d` are not compared, so for
 * example two integer constants with different values are considered
 * equivalent. Confirm that this is intended.
 *
 * @return 1 if equivalent, 0 otherwise. Two NULL nodes are equivalent.
 */
int equiv_nodes(struct ast *n1, struct ast *n2)
{
	if (!n1 || !n2)
		return n1 == n2;

	if (n1->k != n2->k)
		return 0;

	/* strcmp() must never see NULL, and a string on only one side is a
	 * mismatch in either argument order */
	if (n1->s || n2->s) {
		if (!n1->s || !n2->s)
			return 0;

		if (strcmp(n1->s, n2->s) != 0)
			return 0;
	}

	if (!equiv_child_lists(n1->a0, n2->a0))
		return 0;

	if (!equiv_child_lists(n1->a1, n2->a1))
		return 0;

	if (!equiv_child_lists(n1->a2, n2->a2))
		return 0;

	if (!equiv_child_lists(n1->a3, n2->a3))
		return 0;

	return 1;
}

int equiv_node_lists(struct ast *c1, struct ast *c2)
{
	do {
		if (!equiv_nodes(c1, c2))
			return 0;

		c1 = c1->n;
		c2 = c2->n;

	} while (c1 && c2);

	return 1;
}

/* Round @p o up to the next multiple of 3 (values already divisible by 3
 * are returned unchanged). */
size_t align3k(size_t o)
{
	size_t pad = (3 - o % 3) % 3;
	return o + pad;
}

/* Reverse a sibling list in place by relinking `n`; returns the new head
 * (the old tail), or NULL for an empty list. */
struct ast *reverse_ast_list(struct ast *root)
{
	struct ast *reversed = NULL;
	for (struct ast *cur = root, *next; cur; cur = next) {
		next = cur->n;
		cur->n = reversed;
		reversed = cur;
	}

	return reversed;
}

/* Reverse a type list in place by relinking `n`; returns the new head
 * (the old tail), or NULL for an empty list. */
struct type *reverse_type_list(struct type *root)
{
	struct type *reversed = NULL;
	for (struct type *cur = root, *next; cur; cur = next) {
		next = cur->n;
		cur->n = reversed;
		reversed = cur;
	}

	return reversed;
}

void fix_closures(struct ast *root)
{
	while (root) {
		if (root->k != AST_CALL) {
			root = root->n;
			continue;
		}
		struct ast *arg = ast_last(call_args(root));
		if (!arg) {
			root = root->n;
			continue;
		}

		if (arg->k != AST_CLOSURE) {
			root = root->n;
			continue;
		}

		if (closure_body(arg) != NULL) {
			root = root->n;
			continue;
		}

		struct ast *next = root->n;
		struct ast *block = gen_block(next, next->loc);
		closure_body(arg) = block;
		root->n = NULL;
		root = next;
	}
}

/*
 * True when either type is the identifier "auto".
 *
 * @todo massive hack: 'auto' is accepted as a placeholder that matches
 * anything. This will need to be fixed eventually.
 */
static bool special_auto_very_bad(struct type *a, struct type *b)
{
	/* the `b` side is only examined when `a` is not 'auto' */
	return (a->k == TYPE_ID && strcmp(a->id, "auto") == 0)
	    || (b->k == TYPE_ID && strcmp(b->id, "auto") == 0);
}

/**
 * Structural type equality.
 *
 * Two NULL types match; a NULL against a non-NULL type does not. If either
 * side is the 'auto' placeholder, the types match. Otherwise the kinds
 * must be equal, the identifiers must agree when both are present, and
 * the subtype lists must match.
 */
bool types_match(struct type *a, struct type *b)
{
	if (!a || !b)
		return !a && !b;

	if (special_auto_very_bad(a, b))
		return true;

	if (a->k != b->k)
		return false;

	bool ids_differ = a->id && b->id && strcmp(a->id, b->id) != 0;
	if (ids_differ)
		return false;

	return type_lists_match(a->t0, b->t0);
}

/* True when both type lists have the same length and every pair of
 * corresponding elements satisfies types_match(). */
bool type_lists_match(struct type *al, struct type *bl)
{
	if (type_list_len(al) != type_list_len(bl))
		return false;

	for (struct type *x = al, *y = bl; x && y; x = x->n, y = y->n) {
		if (!types_match(x, y))
			return false;
	}

	return true;
}

const char *ast_str(enum ast_kind k)
{
#define CASE(x) case x: return #x;
	switch (k) {
	CASE(AST_CLOSURE);
	CASE(AST_IF);
	CASE(AST_LET);
	CASE(AST_INIT);
	CASE(AST_CALL);
	CASE(AST_PROC_DEF);
	CASE(AST_VAR_DEF);
	CASE(AST_DOT);
	CASE(AST_BLOCK);
	CASE(AST_ID);
	CASE(AST_EMPTY);
	CASE(AST_ADD);
	CASE(AST_SUB);
	CASE(AST_MUL);
	CASE(AST_DIV);
	CASE(AST_REM);
	CASE(AST_LAND);
	CASE(AST_LOR);
	CASE(AST_LSHIFT);
	CASE(AST_RSHIFT);
	CASE(AST_LT);
	CASE(AST_GT);
	CASE(AST_LE);
	CASE(AST_GE);
	CASE(AST_NE);
	CASE(AST_EQ);
	CASE(AST_NEG);
	CASE(AST_LNOT);
	CASE(AST_NOT);
	CASE(AST_CONST_INT);
	CASE(AST_CONST_CHAR);
	CASE(AST_CONST_BOOL);
	CASE(AST_CONST_FLOAT);
	CASE(AST_CONST_STR);
	}
#undef CASE

	return "UNKNOWN";
}