#include <stdio.h>
#include <string.h>
#include <stdarg.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>

#include <posthaste/ast.h>
#include <posthaste/vec.h>

/* Registry of every node handed out by create_empty_ast(), kept so that
 * destroy_ast_nodes() can release all of them in a single pass. */
static struct vec nodes = {0};

/**
 * Release a single AST node together with the strings it owns.
 *
 * Only the node itself and its `id` and `type` strings are freed; the child
 * and sibling pointers are left untouched.
 *
 * @param n Node to release. NULL is accepted and ignored.
 */
static void destroy_ast_node(struct ast *n)
{
	if (!n)
		return;

	/* free(NULL) is a no-op, so absent strings need no guard */
	free(n->id);
	free(n->type);
	free(n);
}

/**
 * Free every AST node recorded in the registry and release the registry
 * itself.
 *
 * Afterwards the registry is put back into its zero-initialized starting
 * state, so no reference to the released storage is kept around for a later
 * node allocation or a repeated teardown to pick up.
 */
void destroy_ast_nodes(void)
{
	foreach_vec(ni, nodes) {
		struct ast *n = vect_at(struct ast *, nodes, ni);
		destroy_ast_node(n);
	}

	vec_destroy(&nodes);

	/* vec_destroy() is defined elsewhere; reset explicitly rather than
	 * relying on whatever it leaves behind in the struct. */
	memset(&nodes, 0, sizeof nodes);
}

/**
 * Allocate a zero-filled AST node and record it in the registry so that
 * destroy_ast_nodes() can free it later. The registry is created on first
 * use.
 *
 * @return The new node, or NULL if the allocation failed. A failed
 *         allocation is not recorded in the registry.
 */
static struct ast *create_empty_ast(void)
{
	if (vec_uninit(nodes)) {
		nodes = vec_create(sizeof(struct ast *));
	}

	struct ast *n = calloc(1, sizeof *n);
	if (!n)
		return NULL;

	vect_append(struct ast *, nodes, &n);
	return n;
}

/**
 * Create a new AST node and populate its fields.
 *
 * The `id` and `type` pointers are stored in the node as-is and are passed to
 * free() when the node is destroyed, so they must be NULL or come from the
 * malloc family.
 *
 * @param kind Node kind.
 * @param a0   First child (may be NULL).
 * @param a1   Second child (may be NULL).
 * @param a2   Third child (may be NULL).
 * @param id   Identifier string stored in the node (may be NULL).
 * @param type Type name string stored in the node (may be NULL).
 * @param v    Integer value stored in the node.
 * @param loc  Source location stored in the node.
 * @return The new node, or NULL if the node could not be allocated. On
 *         failure `id` and `type` are not stored anywhere, so they remain
 *         the caller's to release.
 */
struct ast *gen_ast(enum ast_kind kind,
                    struct ast *a0,
                    struct ast *a1,
                    struct ast *a2,
                    char *id,
                    char *type,
                    int64_t v,
                    struct src_loc loc)
{
	struct ast *n = create_empty_ast();
	/* allocation can fail; report it instead of writing through NULL */
	if (!n)
		return NULL;

	n->k = kind;
	n->a0 = a0;
	n->a1 = a1;
	n->a2 = a2;
	n->id = id;
	n->type = type;
	n->v = v;
	n->loc = loc;
	return n;
}

#ifdef DEBUG
/**
 * Print one piece of debug output to stdout, prefixed with "//" and indented
 * by two spaces per nesting level.
 *
 * @param depth Nesting level; values of zero or less give no indentation.
 * @param fmt   printf-style format string, followed by its arguments.
 */
static void dump(int depth, const char *fmt, ...)
{
	fputs("//", stdout);
	for (int level = depth; level > 0; --level)
		fputs("  ", stdout);

	va_list ap;
	va_start(ap, fmt);
	vprintf(fmt, ap);
	va_end(ap);
}

/**
 * Print the name of a type kind to stdout, without a trailing newline.
 * Anything other than TYPE_DATE or TYPE_INT is printed as "NO TYPE".
 *
 * @param t Type kind to print.
 */
void type_dump(enum type_kind t)
{
	const char *label = "NO TYPE";

	if (t == TYPE_DATE)
		label = "TYPE_DATE";
	else if (t == TYPE_INT)
		label = "TYPE_INT";

	fputs(label, stdout);
}

/**
 * Print a single AST node and everything below it to stdout.
 *
 * The node's kind is printed at the given depth, followed by its type kind in
 * parentheses. Its id and type strings (when present) and each non-NULL child
 * list are printed one level deeper.
 *
 * @param depth Indentation level for this node.
 * @param n     Node to print; NULL is printed as "{NULL}".
 */
void ast_dump(int depth, struct ast *n)
{
	if (!n) {
		dump(depth, "{NULL}\n");
		return;
	}

	/* print the enumerator's own spelling as the node label */
#define DUMP_KIND(kind) case kind: dump(depth, #kind); break
	switch (n->k) {
	DUMP_KIND(AST_BUILTIN_CALL);
	DUMP_KIND(AST_PRINT_STRING);
	DUMP_KIND(AST_PRINT_BOOL);
	DUMP_KIND(AST_PRINT_DATE);
	DUMP_KIND(AST_PRINT_INT);
	DUMP_KIND(AST_CONST_INT);
	DUMP_KIND(AST_CONST_DATE);
	DUMP_KIND(AST_CONST_STRING);
	DUMP_KIND(AST_DATE_ADD);
	DUMP_KIND(AST_DATE_SUB);
	DUMP_KIND(AST_DATE_DIFF);
	DUMP_KIND(AST_EQ);
	DUMP_KIND(AST_LT);
	DUMP_KIND(AST_ADD);
	DUMP_KIND(AST_SUB);
	DUMP_KIND(AST_MUL);
	DUMP_KIND(AST_DIV);
	DUMP_KIND(AST_POS);
	DUMP_KIND(AST_NEG);
	DUMP_KIND(AST_UNLESS);
	DUMP_KIND(AST_UNLESS_EXPR);
	DUMP_KIND(AST_UNTIL);
	DUMP_KIND(AST_FUNC_CALL);
	DUMP_KIND(AST_PROC_CALL);
	DUMP_KIND(AST_ID);
	DUMP_KIND(AST_RETURN);
	DUMP_KIND(AST_ASSIGN);
	DUMP_KIND(AST_DOT);
	DUMP_KIND(AST_ATTR);
	DUMP_KIND(AST_PRINT);
	DUMP_KIND(AST_FUNC_DEF);
	DUMP_KIND(AST_PROC_DEF);
	DUMP_KIND(AST_FORMAL_DEF);
	DUMP_KIND(AST_VAR_DEF);
	}
#undef DUMP_KIND

	/* finish the label line with the node's type kind */
	printf(" (");
	type_dump(n->t);
	printf(")\n");

	/* everything that belongs to this node goes one level deeper */
	const int inner = depth + 1;

	if (n->id)
		dump(inner, "%s\n", n->id);

	if (n->type)
		dump(inner, "%s\n", n->type);

	struct ast *const children[] = { n->a0, n->a1, n->a2 };
	for (size_t i = 0; i < sizeof(children) / sizeof(children[0]); ++i) {
		if (children[i])
			ast_dump_list(inner, children[i]);
	}
}

void ast_dump_list(int depth, struct ast *root)
{
	if (!root) {
		dump(depth, "{NULL}\n");
		return;
	}

	foreach_node(n, root) {
		ast_dump(depth, n);
	}
}
#endif /* DEBUG */

/**
 * Visit a node and the child lists hanging off it.
 *
 * `before` is called on the node first, then the a0, a1 and a2 child lists
 * are visited in that order, and finally `after` is called on the node. The
 * first non-zero result from any step stops the traversal and is returned.
 *
 * @param before Callback run before the children; may be NULL.
 * @param after  Callback run after the children; may be NULL.
 * @param n      Node to visit; NULL is a no-op.
 * @param d      Opaque pointer handed to both callbacks.
 * @return 0 if the whole traversal completed, otherwise the first non-zero
 *         callback result.
 */
int ast_visit(ast_callback_t before, ast_callback_t after, struct ast *n,
              void *d)
{
	if (!n)
		return 0;

	int status = 0;

	if (before) {
		status = before(n, d);
		if (status)
			return status;
	}

	if (n->a0) {
		status = ast_visit_list(before, after, n->a0, d);
		if (status)
			return status;
	}

	if (n->a1) {
		status = ast_visit_list(before, after, n->a1, d);
		if (status)
			return status;
	}

	if (n->a2) {
		status = ast_visit_list(before, after, n->a2, d);
		if (status)
			return status;
	}

	if (after)
		status = after(n, d);

	return status;
}

/**
 * Visit every node of a list in order with ast_visit(), stopping at the first
 * node whose visit returns non-zero.
 *
 * @param before Callback run on each node before its children; may be NULL.
 * @param after  Callback run on each node after its children; may be NULL.
 * @param l      First node of the list.
 * @param d      Opaque pointer handed to the callbacks.
 * @return 0 if every node was visited, otherwise the first non-zero result.
 */
int ast_visit_list(ast_callback_t before, ast_callback_t after, struct ast *l,
                   void *d)
{
	foreach_node(cur, l) {
		int status = ast_visit(before, after, cur, d);
		if (status)
			return status;
	}

	return 0;
}

/**
 * Find the final node of a list by following the `n` links.
 *
 * @param list First node of the list.
 * @return The node whose `n` link is NULL, or NULL if `list` is NULL.
 */
struct ast *ast_last(struct ast *list)
{
	struct ast *tail = list;

	if (tail) {
		for (; tail->n; tail = tail->n)
			;
	}

	return tail;
}

size_t ast_list_len(struct ast *l)
{
	size_t count = 0;
	foreach_node(n, l) {
		count++;
	}

	return count;
}