/* SPDX-License-Identifier: copyleft-next-0.3.1 */
/* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */

/** @file scope.c
 *
 * Implementations for scope handling stuff.
 */

#include <stddef.h>
#include <string.h>
#include <stdio.h>
#include <stdarg.h>
#include <stdlib.h>
#include <assert.h>

#include <fwd/debug.h>
#include <fwd/scope.h>

/**
 * Allocate a fresh, zero-initialised scope.
 *
 * Every successfully created scope is stamped with the next value of a
 * function-local counter, so scope numbers are distinct and increase in
 * creation order. The counter is a plain static, so this function must not
 * be called concurrently from multiple threads.
 *
 * @return The new scope, or NULL (after reporting an internal error) if the
 *         allocation failed. The counter is not advanced on failure.
 */
struct scope *create_scope(void)
{
	/* if I ever try making the parser multithreaded, this should be atomic. */
	static size_t counter = 0;

	struct scope *scope = calloc(1, sizeof(*scope));
	if (!scope) {
		internal_error("ran out of memory allocating scope");
		return NULL;
	}

	scope->number = counter++;
	return scope;
}

/**
 * Release every node of a singly linked symbol list.
 *
 * Only the list nodes are freed; the id strings and AST nodes they refer to
 * are not touched.
 *
 * @param visible Head of the list; NULL is accepted and does nothing.
 */
static void destroy_visible(struct visible *visible)
{
	while (visible) {
		struct visible *following = visible->next;
		free(visible);
		visible = following;
	}
}

/**
 * Recursively free a scope, its symbol list and all of its child scopes.
 *
 * For a root scope (one with no parent) the file context's buffer and file
 * name are freed as well.
 *
 * @param scope Scope to destroy; NULL is accepted and does nothing.
 */
void destroy_scope(struct scope *scope)
{
	if (scope == NULL)
		return;

	/* Child scopes receive a by-value copy of the parent's fctx in
	 * scope_add_scope(), so only the root releases the shared buffers. */
	if (scope->parent == NULL) {
		free((void *)scope->fctx.fbuf);
		free((void *)scope->fctx.fname);
	}

	destroy_visible(scope->symbols);

	for (struct scope *child = scope->children, *after; child; child = after) {
		/* grab the sibling link before the child is freed */
		after = child->next;
		destroy_scope(child);
	}

	free(scope);
}

/**
 * Allocate a zero-initialised symbol entry binding @p id to @p node.
 *
 * @param id   Symbol name; stored as-is, not copied.
 * @param node AST node the name refers to.
 * @return The new entry, or NULL if allocation failed.
 */
static struct visible *create_visible(char *id, struct ast *node)
{
	struct visible *entry = calloc(1, sizeof(*entry));
	if (entry) {
		entry->id = id;
		entry->node = node;
	}

	return entry;
}

/**
 * Push a variable binding onto the front of @p scope's symbol list.
 *
 * No redefinition checking is done here; scope_add_var() does that.
 *
 * @param scope Scope to add the binding to.
 * @param id    Name of the variable; stored as-is, not copied.
 * @param var   AST node of the variable.
 * @return The new entry, or NULL if allocation failed (scope left unchanged).
 */
struct visible *create_var(struct scope *scope, char *id, struct ast *var)
{
	struct visible *entry = create_visible(id, var);
	if (entry) {
		entry->next = scope->symbols;
		scope->symbols = entry;
	}

	return entry;
}

/**
 * Push a procedure binding onto the front of @p scope's symbol list.
 *
 * No redefinition checking is done here; scope_add_proc() does that.
 *
 * @param scope Scope to add the binding to.
 * @param id    Name of the procedure; stored as-is, not copied.
 * @param proc  AST node of the procedure.
 * @return The new entry, or NULL if allocation failed (scope left unchanged).
 */
struct visible *create_proc(struct scope *scope, char *id, struct ast *proc)
{
	struct visible *entry = create_visible(id, proc);
	if (entry) {
		entry->next = scope->symbols;
		scope->symbols = entry;
	}

	return entry;
}

/**
 * Declare the variable @p var in @p scope.
 *
 * Only @p scope itself is searched for an existing symbol with the same
 * name; enclosing scopes are not consulted.
 *
 * @param scope Scope to declare the variable in.
 * @param var   AST node of the variable; its name comes from var_id().
 * @return 0 on success. Returns -1 if the name is already defined in this
 *         scope (a semantic error is reported). Also returns -1 if the
 *         symbol entry could not be allocated (an internal error is reported).
 */
int scope_add_var(struct scope *scope, struct ast *var)
{
	struct ast *exists = scope_find_symbol(scope, var_id(var));
	if (exists) {
		semantic_error(scope->fctx, var, "var redefined");
		semantic_info(scope->fctx, exists, "previously here");
		return -1;
	}

	/* create_var() returns NULL on allocation failure; don't report success
	 * for a variable that never made it into the scope. */
	if (!create_var(scope, var_id(var), var)) {
		internal_error("ran out of memory adding var to scope");
		return -1;
	}

	return 0;
}

/**
 * Declare the procedure @p proc in @p scope.
 *
 * The name is checked against @p scope and every ancestor scope (via
 * file_scope_find_symbol()).
 *
 * @param scope Scope to declare the procedure in.
 * @param proc  AST node of the procedure; must be an AST_PROC_DEF.
 * @return 0 on success. Returns -1 if the name is already visible (a semantic
 *         error is reported). Also returns -1 if the symbol entry could not
 *         be allocated (an internal error is reported).
 */
int scope_add_proc(struct scope *scope, struct ast *proc)
{
	assert(proc->k == AST_PROC_DEF);
	struct ast *exists = file_scope_find_symbol(scope, proc_id(proc));
	if (exists) {
		semantic_error(scope->fctx, proc, "proc redefined");
		semantic_info(scope->fctx, exists, "previously here");
		return -1;
	}

	/* always add to scope, do resolve checking later */
	if (!create_proc(scope, proc_id(proc), proc)) {
		/* create_proc() returns NULL on allocation failure */
		internal_error("ran out of memory adding proc to scope");
		return -1;
	}

	return 0;
}

static struct ast *scope_find_visible(struct visible *v, char *id)
{
	if (!v)
		return NULL;

	foreach_visible(n, v) {
		struct ast *node = n->node;
		if (same_id(node->s, id))
			return node;
	}

	return NULL;
}

/**
 * Look up @p id in @p scope only, and accept it only if it names a procedure.
 *
 * @return The AST_PROC_DEF node. Returns NULL if the name is not in this
 *         scope, or if the first match is not a procedure.
 */
struct ast *scope_find_proc(struct scope *scope, char *id)
{
	struct ast *found = scope_find_visible(scope->symbols, id);
	return (found && found->k == AST_PROC_DEF) ? found : NULL;
}

/**
 * Look up @p id in @p scope and its ancestors, and accept it only if it
 * names a procedure.
 *
 * The nearest symbol with that name is the one checked. If it is not an
 * AST_PROC_DEF, NULL is returned and outer scopes are not searched further.
 *
 * @return The AST_PROC_DEF node, or NULL.
 */
struct ast *file_scope_find_proc(struct scope *scope, char *id)
{
	struct ast *found = file_scope_find_symbol(scope, id);
	return (found && found->k == AST_PROC_DEF) ? found : NULL;
}

/**
 * Look up a symbol of any kind in @p scope only (ancestors are not searched).
 *
 * @return The matching AST node, or NULL.
 */
struct ast *scope_find_symbol(struct scope *scope, char *id)
{
	struct visible *symbols = scope->symbols;
	return scope_find_visible(symbols, id);
}

/**
 * Look up a symbol of any kind, starting in @p scope and walking outwards
 * through its parent chain.
 *
 * @param scope Innermost scope to search; may be NULL.
 * @param id    Name to look for.
 * @return The match from the innermost scope that has one, or NULL.
 */
struct ast *file_scope_find_symbol(struct scope *scope, char *id)
{
	for (struct scope *cur = scope; cur != NULL; cur = cur->parent) {
		struct ast *hit = scope_find_symbol(cur, id);
		if (hit)
			return hit;
	}

	return NULL;
}

/**
 * Look up @p id in @p scope only, and accept it only if it names a variable.
 *
 * @return The AST_VAR_DEF node. Returns NULL if the name is not in this
 *         scope, or if the first match is not a variable definition.
 */
struct ast *scope_find_var(struct scope *scope, char *id)
{
	struct ast *found = scope_find_visible(scope->symbols, id);
	return (found && found->k == AST_VAR_DEF) ? found : NULL;
}

/**
 * Look up a variable named @p id, starting in @p scope and walking outwards
 * through its parent chain.
 *
 * Each scope is checked with scope_find_var(). A scope whose match is not a
 * variable is skipped, and the search continues in its parent.
 *
 * @param scope Innermost scope to search; may be NULL.
 * @param id    Name to look for.
 * @return The innermost AST_VAR_DEF match, or NULL.
 */
struct ast *file_scope_find_var(struct scope *scope, char *id)
{
	for (struct scope *cur = scope; cur != NULL; cur = cur->parent) {
		struct ast *hit = scope_find_var(cur, id);
		if (hit)
			return hit;
	}

	return NULL;
}

/**
 * Attach @p child as the newest child of @p parent.
 *
 * The child inherits a by-value copy of the parent's file context and is
 * pushed onto the front of the parent's children list.
 *
 * @param parent Scope to attach to; must not be NULL.
 * @param child  Scope being attached; must not be NULL.
 */
void scope_add_scope(struct scope *parent, struct scope *child)
{
	assert(parent);
	assert(child);

	/* tree linkage */
	child->parent = parent;
	child->fctx = parent->fctx;

	/* prepend to the parent's list of children */
	child->next = parent->children;
	parent->children = child;
}