/* SPDX-License-Identifier: copyleft-next-0.3.1 */
/* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */

#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <stdint.h>
#include <assert.h>

#include <fwd/lower.h>
#include <fwd/scope.h>
#include <fwd/mod.h>

/* stuff to be careful of:
 *
 * + Really easy to forget to pass the stack argument via reference, since it's
 *   just a void* we don't get a warning for it
 *
 * + lower_call() and lower_closure_call() share a lot of the same stuff, might
 *   want to try and merge them somehow, otherwise too easy to make a change in
 *   one and forget the other
 */

/* placeholder, should probably scale according to biggest need
 *
 * Size of one stack frame of the generated program. Each generated
 * procedure static_asserts that its `struct <name>_ctx` fits within the
 * generated program's FWD_FRAME_SIZE, so the unit is bytes (sizeof). */
#define FWD_FRAME_SIZE 1024

/* Write the C identifier used for @id to @f: the source name, then "_s" and
 * the number of the scope @id belongs to. Equal names from different scopes
 * therefore map to distinct C identifiers. */
static inline void mangle(FILE *f, struct ast *id)
{
	assert(id->k == AST_PROC_DEF
	       || id->k == AST_VAR_DEF
	       || id->k == AST_ID);
	assert(id->s);
	assert(id->scope);

	fprintf(f, "%s_s%zu",
		id->s,
		id->scope->number);
}

/* Like mangle(), but return the mangled name as a newly allocated string.
 * The caller owns the result and must free() it. */
static inline char *mangle2(struct ast *id)
{
	char *mangled = NULL;
	size_t mangled_len = 0;

	FILE *sink = open_memstream(&mangled, &mangled_len);
	assert(sink);
	mangle(sink, id);
	fclose(sink);

	assert(mangled);
	return mangled;
}

/* printf-style formatting into a newly allocated string.
 * The buffer comes from open_memstream(), so the caller releases it with
 * free(). */
static inline __attribute__((format (printf, 1, 2)))
char *buildstr(const char *fmt, ...)
{
	char *result = NULL;
	size_t result_len = 0;

	FILE *sink = open_memstream(&result, &result_len);
	assert(sink);

	va_list ap;
	va_start(ap, fmt);
	vfprintf(sink, fmt, ap);
	va_end(ap);

	fclose(sink);
	assert(result);
	return result;
}

/** @todo semantics in this file are a bit unclear, should probably do some kind
 * of "each function starts and ends on an indented empty line" or something */

/* string_vec: growable vector of `char *`, generated by conts/vec.h. Used to
 * collect the type, declaration and definition text emitted by this pass. */
#define VEC_NAME string_vec
#define VEC_TYPE char *
#include <conts/vec.h>

/* proc_set: map from a procedure's struct ast pointer to itself, used as a
 * set of procedures that have already been (or are being) lowered. Hashing
 * and comparison use the pointer value only. */
/* NOTE(review): MAP_CMP yields an unsigned difference, so it can only
 * signal equal / not-equal, never ordering; if conts/map.h narrows it to
 * int, pointers differing by a multiple of 2^32 would compare equal.
 * Confirm how conts/map.h consumes MAP_CMP. */
#define MAP_NAME proc_set
#define MAP_KEY struct ast *
#define MAP_TYPE struct ast *
#define MAP_CMP(a, b) ((uintptr_t)(a) - (uintptr_t)(b))
#define MAP_HASH(a) ((uintptr_t)(a))
#include <conts/map.h>

/* State threaded through the lowering pass. Output streams and the string
 * collections are shared between nested states; see create_state(). */
struct state {
	/* prefix of generated identifiers such as <prefix>_callN and
	 * <prefix>_closureN; lower_proc() sets it to the mangled proc name */
	char *prefix;
	/* indentation level of state->code, two spaces per level (indent()) */
	long indent;
	/* counter handed out by uniq() to keep generated names distinct */
	size_t uniq;
	/* stream receiving the members of the current `struct <prefix>_ctx` */
	FILE *ctx;
	/* stream receiving the C code of the function being generated */
	FILE *code;
	/* procedure whose body is currently being lowered */
	struct ast *current;
	/* forward declarations; lower() prints them after the types */
	struct string_vec *decls;
	/* function definitions; lower() prints them after the declarations */
	struct string_vec *defns;
	/* struct definitions; lower() prints them first */
	struct string_vec *types;
	/* procedures already lowered or in progress, see proc_lowered() */
	struct proc_set *procs;
};

static struct state create_state(struct state *parent)
{
	return (struct state){
		.uniq = parent->uniq,
		.prefix = parent->prefix,
		.indent = parent->indent,
		.current = parent->current,
		.ctx = parent->ctx,
		.code = parent->code,
		.decls = parent->decls,
		.defns = parent->defns,
		.types = parent->types,
		.procs = parent->procs,
	};
}

/* Return the current value of the per-state naming counter and advance it,
 * so that every generated identifier gets its own suffix. */
static size_t uniq(struct state *state)
{
	size_t id = state->uniq;
	state->uniq = id + 1;
	return id;
}

/* Nest subsequent generated code one level (two spaces) deeper. */
static void increase_indent(struct state *state)
{
	state->indent += 1;
}

/* Undo one increase_indent(). */
static void decrease_indent(struct state *state)
{
	state->indent -= 1;
}

/* Write the current indentation (two spaces per level) to state->code.
 * Padding an empty string to the field width emits exactly that many
 * spaces, and nothing at all for level zero. */
static void indent(struct state *state)
{
	int width = (int)(2 * state->indent);
	fprintf(state->code, "%*s", width, "");
}

/* Has @proc already been registered with add_proc()? lower_proc() registers
 * a procedure before lowering its body, so this also covers procedures
 * that are still in progress (e.g. recursive calls). */
static bool proc_lowered(struct state *state, struct ast *proc)
{
	assert(proc->k == AST_PROC_DEF);

	if (proc_set_find(state->procs, proc))
		return true;

	return false;
}

/* Record @proc in state->procs so later calls skip lowering it again. */
static void add_proc(struct state *state, struct ast *proc)
{
	assert(proc->k == AST_PROC_DEF);

	struct ast **slot = proc_set_insert(state->procs, proc, proc);
	/* insertion must succeed and leave @proc as the stored value */
	assert(slot);
	assert(*slot == proc);
}

/* Queue the forward declaration @decl for output; lower() later prints and
 * frees it, so ownership passes to the collection. */
static void add_decl(struct state *state, char *decl)
{
	struct string_vec *decls = state->decls;
	string_vec_append(decls, decl);
}

/* Queue the function definition @defn for output; lower() later prints and
 * frees it, so ownership passes to the collection. */
static void add_defn(struct state *state, char *defn)
{
	struct string_vec *defns = state->defns;
	string_vec_append(defns, defn);
}

/* Queue the struct definition @type for output; lower() later prints and
 * frees it, so ownership passes to the collection. */
static void add_type(struct state *state, char *type)
{
	struct string_vec *types = state->types;
	string_vec_append(types, type);
}

/* Forward declarations: the lowering functions are mutually recursive
 * (lower_call() lowers the called procedure via lower_proc(), whose body is
 * lowered via lower_stmt_list(), and closures in expressions lower their
 * own parameters and bodies). All return 0 on success, -1 on failure. */
static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last);
static int lower_block(struct state *state, struct ast *block, bool last);
static int lower_params(struct state *state, struct ast *params);
static int lower_expr(struct state *state, struct ast *expr);
static int lower_proc(struct state *state, struct ast *proc);

/* Write the C spelling of @type to @f. Only the fixed-width integer types
 * and closures are supported; anything else is an internal error. */
static void _type_str(FILE *f, struct type *type)
{
	assert(type);

	const char *spelling = NULL;
	switch (type->k) {
	case TYPE_I8:      spelling = "int8_t";        break;
	case TYPE_U8:      spelling = "uint8_t";       break;
	case TYPE_I16:     spelling = "int16_t";       break;
	case TYPE_U16:     spelling = "uint16_t";      break;
	case TYPE_I32:     spelling = "int32_t";       break;
	case TYPE_U32:     spelling = "uint32_t";      break;
	case TYPE_I64:     spelling = "int64_t";       break;
	case TYPE_U64:     spelling = "uint64_t";      break;
	case TYPE_CLOSURE: spelling = "fwd_closure_t"; break;
	default:
		internal_error("unhandled type lowering for %s", type_str(type));
		abort();
	}

	fputs(spelling, f);
}

/* Return the C spelling of @type (see _type_str()) as a newly allocated
 * string that the caller must free(). */
static char *lower_type_str(struct type *type)
{
	assert(type);

	char *spelled = NULL;
	size_t spelled_len = 0;
	FILE *out = open_memstream(&spelled, &spelled_len);
	assert(out);

	_type_str(out, type);
	fclose(out);

	assert(spelled);
	return spelled;
}

/* Lower a call through the closure-typed variable @def.
 *
 * The closure stored in ctx-><mangled def> provides both the callee (.call)
 * and the argument buffer (.args). A `struct <q>` describing the argument
 * layout (leading fwd_start_t, then one aN per argument) goes to the type
 * list. Each argument expression is stored into the buffer, then the callee
 * is invoked: as a FWD_MUSTTAIL tail call when @last is set, otherwise with
 * the returned stack assigned back to `stack`.
 *
 * Keep in sync with lower_call(), which follows the same scheme.
 *
 * Returns 0 on success, -1 if lowering an argument failed. */
static int lower_closure_call(struct state *state, struct ast *call, struct ast *def, bool last)
{
	char *q = buildstr("%s_call%zu", state->prefix, uniq(state));
	char *args = mangle2(def);

	char *ctx_buf = NULL; size_t ctx_size = 0;
	FILE *ctx = open_memstream(&ctx_buf, &ctx_size);
	assert(ctx);
	fprintf(ctx, "struct %s {\n  fwd_start_t start;\n", q);

	indent(state);
	fprintf(state->code, "fwd_call_t %s = ctx->%s.call;\n", q, args);

	indent(state);
	fprintf(state->code, "struct %s *%s_ctx = ctx->%s.args;\n", q, q, args);

	int ret = 0;
	size_t idx = 0;
	bool returning = false;
	foreach_node(a, call_args(call)) {
		/* a closure argument points into this frame (&ctx->...), so
		 * the frame must stay alive across the call */
		returning |= a->k == AST_CLOSURE;

		char *type = lower_type_str(a->t);
		fprintf(ctx, "  %s a%zu;\n", type, idx);
		free(type);

		indent(state);
		fprintf(state->code, "%s_ctx->a%zu = ", q, idx);
		/* assign the function-level ret (no shadowing local) so a
		 * failure actually reaches the caller */
		ret = lower_expr(state, a);
		fprintf(state->code, ";\n");
		if (ret)
			goto out;

		idx++;
	}

out:
	/* a tail call never comes back here, so release the frame first */
	if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) {
		indent(state);
		fprintf(state->code, "fwd_stack_free(&stack, ctx);\n");
	}

	/* q names the local fwd_call_t holding the callee */
	indent(state);
	if (last) {
		fprintf(state->code, "FWD_MUSTTAIL return %s(stack, %s_ctx);\n",
				q, q);
	}
	else {
		fprintf(state->code, "stack = %s(stack, %s_ctx);\n",
				q, q);
	}

	fprintf(ctx, "};\n\n");
	fclose(ctx);
	assert(ctx_buf);

	add_type(state, ctx_buf);
	free(args);
	free(q);
	return ret;
}

/* Lower the closure literal @closure into a separate C function named
 * <prefix>_closureN.
 *
 * The closure's bindings become members of the enclosing procedure's
 * context struct, after a `fwd_start_t <name>_start` marker. The generated
 * function recovers the enclosing context from @args with FWD_CONTAINER_OF
 * on that marker. Its prototype is added to the declarations and its body
 * to the definitions.
 *
 * On return, *name_out holds the function name and *args_out the marker
 * member name. Both are newly allocated and owned by the caller, and are
 * set even when lowering the body failed.
 *
 * Returns 0 on success, -1 if lowering the closure body failed. */
static int lower_closure(struct state *state, struct ast *closure, char **name_out, char **args_out)
{
	assert(name_out);
	assert(args_out);

	char *name = buildstr("%s_closure%zu", state->prefix, uniq(state));
	char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name);
	char *decl = buildstr("%s;\n", proto);
	char *start_of = buildstr("%s_start", name);
	add_decl(state, decl);

	char *code_buf = NULL;
	size_t code_size = 0;

	struct state new_state = create_state(state);
	new_state.code = open_memstream(&code_buf, &code_size);
	new_state.indent = 0;
	assert(new_state.code);

	fprintf(new_state.code, "%s\n{\n", proto);

	/** @todo unsure if this has the same address as the first element in
	 * the frame or the last in the previous, I guess we shall soon see */
	fprintf(new_state.ctx, "  fwd_start_t %s;\n", start_of);

	int ret = lower_params(&new_state, closure_bindings(closure));
	assert(ret == 0);

	increase_indent(&new_state);
	indent(&new_state);
	fprintf(new_state.code, "struct %s_ctx *ctx = FWD_CONTAINER_OF(args, struct %s_ctx, %s);\n",
			state->prefix, state->prefix, start_of);

	struct ast *block = closure_body(closure);
	ret = lower_stmt_list(&new_state, block_body(block), true);

	/* the body generated names with the same prefix, so hand the counter
	 * back; otherwise the enclosing code would reuse those numbers and
	 * emit duplicate <prefix>_callN structs */
	state->uniq = new_state.uniq;

	fprintf(new_state.code, "}\n\n");
	fclose(new_state.code);
	assert(code_buf);

	add_defn(state, code_buf);
	free(proto);

	*name_out = name;
	*args_out = start_of;
	return ret;
}

/* Lower a comparison as `<left> <op> <right>` (only <= and < are supported).
 * Returns 0 on success, -1 if either operand failed to lower. */
static int lower_comparison(struct state *state, struct ast *expr)
{
	int err = lower_expr(state, comparison_left(expr));
	if (err)
		return -1;

	const char *op = NULL;
	switch (expr->k) {
	case AST_LE: op = " <= "; break;
	case AST_LT: op = " < ";  break;
	default:
		internal_error("unhandled lowering for comparison %s", ast_str(expr->k));
		abort();
	}

	fputs(op, state->code);
	return lower_expr(state, comparison_right(expr));
}

/* Lower a binary arithmetic operation as `<left> <op> <right>` (only - and
 * + are supported). Returns 0 on success, -1 if either operand failed to
 * lower. */
static int lower_binop(struct state *state, struct ast *expr)
{
	if (lower_expr(state, binop_left(expr)))
		return -1;

	switch (expr->k) {
	case AST_SUB:
		fprintf(state->code, " - ");
		break;

	case AST_ADD:
		fprintf(state->code, " + ");
		break;

	default:
		/* this is a binop, not a comparison; say so in the report */
		internal_error("unhandled lowering for binop %s", ast_str(expr->k));
		abort();
	}

	return lower_expr(state, binop_right(expr));
}

/* Lower expression @expr as a C expression written to state->code.
 *
 * Identifiers become reads of their mangled context slot (ctx-><name>).
 * Integer constants are printed literally. Closure literals are lowered to
 * a separate function and appear as a fwd_closure_t pointing at that
 * function and at their slot inside this context.
 *
 * Returns 0 on success, -1 on failure. */
static int lower_expr(struct state *state, struct ast *expr)
{
	if (is_comparison(expr))
		return lower_comparison(state, expr);

	if (is_binop(expr))
		return lower_binop(state, expr);

	switch (expr->k) {
	case AST_CLOSURE: {
		char *name = NULL, *args = NULL;
		if (lower_closure(state, expr, &name, &args)) {
			/* lower_closure() hands both strings over even on
			 * failure, so release them here instead of leaking */
			free(name);
			free(args);
			return -1;
		}

		fprintf(state->code, "(fwd_closure_t){%s, &ctx->%s}", name, args);
		free(name);
		free(args);
		break;
	}

	case AST_ID: {
		struct ast *def = file_scope_find_symbol(expr->scope, id_str(expr));
		/* mangle2() dereferences def, fail loudly instead */
		assert(def);
		char *name = mangle2(def);
		fprintf(state->code, "ctx->%s", name);
		free(name);
		break;
	}

	case AST_CONST_INT: {
		fprintf(state->code, "%lld", (long long)int_val(expr));
		break;
	}
	default:
		internal_error("unhandled expr lowering: %s", ast_str(expr->k));
		abort();
	}

	return 0;
}

/* Lower a call statement.
 *
 * Calls through closure-typed variables are delegated to
 * lower_closure_call(). Otherwise the target procedure is lowered (once)
 * and a `struct <q>` describing the argument layout (leading fwd_start_t,
 * then one aN per argument) goes to the type list. The arguments are
 * written into the buffer at ctx->global_args, and the target is called: as
 * a FWD_MUSTTAIL tail call when @last is set, otherwise with the returned
 * stack assigned back to `stack`.
 *
 * Keep in sync with lower_closure_call(), which follows the same scheme.
 *
 * Returns 0 on success, -1 on failure. */
static int lower_call(struct state *state, struct ast *call, bool last)
{
	struct ast *expr = call_expr(call);
	assert(expr->k == AST_ID);

	struct ast *def = file_scope_find_symbol(call->scope, id_str(expr));
	assert(def);

	if (def->k == AST_VAR_DEF)
		return lower_closure_call(state, call, def, last);

	if (lower_proc(state, def))
		return -1;

	char *q = buildstr("%s_call%zu", state->prefix, uniq(state));

	char *ctx_buf = NULL; size_t ctx_size = 0;
	FILE *ctx = open_memstream(&ctx_buf, &ctx_size);
	assert(ctx);
	fprintf(ctx, "struct %s {\n  fwd_start_t start;\n", q);

	indent(state);
	fprintf(state->code, "struct %s *%s = ctx->global_args;\n", q, q);

	int ret = 0;
	size_t idx = 0;
	bool returning = false;
	foreach_node(a, call_args(call)) {
		/* a closure argument points into this frame (&ctx->...), so
		 * the frame must stay alive across the call */
		returning |= a->k == AST_CLOSURE;

		char *type = lower_type_str(a->t);
		fprintf(ctx, "  %s a%zu;\n", type, idx);
		free(type);

		/* lower through @state itself (not a copy) so uniq numbers
		 * consumed by closures in the arguments are not reused later */
		indent(state);
		fprintf(state->code, "%s->a%zu = ", q, idx);
		/* assign the function-level ret (no shadowing local) so a
		 * failure actually reaches the caller */
		ret = lower_expr(state, a);
		fprintf(state->code, ";\n");

		if (ret)
			goto out;

		idx++;
	}

out:
	if (!returning && last && ast_flags(state->current, AST_REQ_FRAME)) {
		/** @todo unsure if this applies to all cases but seems to work
		 * for the simple examples I've tried so far */
		indent(state);
		fprintf(state->code, "fwd_stack_free(&stack, ctx);\n");
	}

	char *target = mangle2(def);
	indent(state);
	if (last) {
		fprintf(state->code, "FWD_MUSTTAIL return %s(stack, %s);\n",
				target, q);
	}
	else {
		fprintf(state->code, "stack = %s(stack, %s);\n",
				target, q);
	}

	fprintf(ctx, "};\n\n");
	fclose(ctx);
	assert(ctx_buf);

	add_type(state, ctx_buf);
	free(target);
	free(q);
	return ret;
}

/* Lower an if statement (with optional else) into a C if/else whose
 * branches are braced blocks. Both branches inherit @last because either
 * one may be the final code executed. Returns 0 on success, -1 on failure. */
static int lower_if(struct state *state, struct ast *stmt, bool last)
{
	indent(state);
	fputs("if (", state->code);
	if (lower_expr(state, if_cond(stmt)) != 0)
		return -1;
	fputs(")\n", state->code);

	if (lower_block(state, if_body(stmt), last) != 0)
		return -1;

	struct ast *otherwise = if_else(stmt);
	if (!otherwise)
		return 0;

	indent(state);
	fputs("else\n", state->code);
	return lower_block(state, otherwise, last);
}

/* Lower one statement; @last marks it as being in tail position. Only
 * calls and if statements are supported. Returns 0 on success, -1 on
 * failure. */
static int lower_stmt(struct state *state, struct ast *stmt, bool last)
{
	if (stmt->k == AST_CALL)
		return lower_call(state, stmt, last);

	if (stmt->k == AST_IF)
		return lower_if(state, stmt, last);

	internal_error("unhandled statement kind %s", ast_str(stmt->k));
	abort();
}

/* Lower each statement of @stmt_list in order. When @last is set the list
 * is in tail position, and only its final statement (the node without a
 * ->n successor) inherits that. Stops at the first failure and returns -1;
 * returns 0 otherwise. */
static int lower_stmt_list(struct state *state, struct ast *stmt_list, bool last)
{
	foreach_node(stmt, stmt_list) {
		bool tail = last && !stmt->n;
		if (lower_stmt(state, stmt, tail) != 0)
			return -1;
	}

	return 0;
}

/* Lower @block as a braced C block, indenting its statements one level.
 * Returns the result of lowering the block body. */
static int lower_block(struct state *state, struct ast *block, bool last)
{
	indent(state);
	fputs("{\n", state->code);

	state->indent++;
	int rc = lower_stmt_list(state, block_body(block), last);
	state->indent--;

	indent(state);
	fputs("}\n", state->code);
	return rc;
}

/* Add a `<type> <mangled name>;` member for @param to the context struct
 * being written to state->ctx. Always returns 0. */
static int lower_param(struct state *state, struct ast *param)
{
	char *c_type = lower_type_str(param->t);
	char *c_name = mangle2(param);

	fprintf(state->ctx, "  %s %s;\n", c_type, c_name);

	free(c_name);
	free(c_type);
	return 0;
}

/* Add a context member for every parameter in @params (see lower_param()).
 * Returns 0, or -1 as soon as one parameter fails. */
static int lower_params(struct state *state, struct ast *params)
{
	int rc = 0;
	foreach_node(p, params) {
		rc = lower_param(state, p);
		if (rc != 0)
			return -1;
	}

	return 0;
}

/* Map @t to the fwd runtime type tag that lower_extern_proc() stores as the
 * first member of each fwd_arg_t it builds. Only i64 and pointer types are
 * supported; anything else is reported as an internal error before
 * aborting, like _type_str(). */
static size_t fwd_typeid(struct type *t)
{
	/** @todo this should maybe be cached somewhere earlier? */
	switch (t->k) {
	case TYPE_I64: return FWD_I64;
	case TYPE_PTR: return FWD_PTR;
	default:
		break;
	}

	internal_error("unhandled extern argument type %s", type_str(t));
	abort();
}

/* Name of the fwd_arg_t union member that holds a value of type @t, as used
 * in the designated initializer `{.<member> = ...}` emitted by
 * lower_extern_proc(). Only i64 and pointer types are supported; anything
 * else is reported as an internal error before aborting. */
static const char *fwd_typeparam(struct type *t)
{
	switch (t->k) {
	case TYPE_I64: return "i64";
	case TYPE_PTR: return "p";
	default:
		break;
	}

	internal_error("unhandled extern argument type %s", type_str(t));
	abort();
}

/* Lower a procedure without a body into a C wrapper around the external
 * symbol proc_id(proc).
 *
 * The wrapper reads its arguments from `struct <name>_ctx` (members a0..aN).
 * It repacks them into a fwd_arg array taken from the fwd stack (slot 0 is
 * left for the return value), calls the external function, releases the
 * array and returns the stack. Only void return types are supported.
 *
 * Emits a forward declaration, the wrapper definition and the context
 * struct. Always returns 0. */
static int lower_extern_proc(struct state *state, struct ast *proc)
{
	/* only void external functions supported atm */
	struct type *rtype = proc_rtype(proc);
	assert(rtype->k == TYPE_VOID);

	add_proc(state, proc);
	char *name = mangle2(proc);
	char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)",
			name);

	char *decl = buildstr("%s;\n", proto);
	add_decl(state, decl);

	char *code_buf = NULL, *ctx_buf = NULL;
	size_t code_size = 0, ctx_size = 0;

	struct state new_state = create_state(state);
	new_state.indent = 0;
	new_state.prefix = name;
	new_state.current = proc;
	new_state.uniq = 0;
	new_state.code = open_memstream(&code_buf, &code_size);
	new_state.ctx = open_memstream(&ctx_buf, &ctx_size);
	assert(new_state.code);
	assert(new_state.ctx);

	/* NOTE(review): the caller-side argument structs built by lower_call()
	 * start with `fwd_start_t start;`, this context does not, so ctx->aN
	 * lines up with the caller's aN only if fwd_start_t occupies no
	 * storage. Confirm against <fwd.h>. */
	fprintf(new_state.ctx, "struct %s_ctx {\n", name);
	fprintf(new_state.code, "%s\n{\n", proto);
	increase_indent(&new_state);

	indent(&new_state);
	fprintf(new_state.code, "extern long %s(fwd_extern_args_t);\n", proc_id(proc));

	/* for now, allocate a new frame for arguments. Might want to
	 * simplify this by always using the 'external' format for argument
	 * passing, even internally */
	indent(&new_state);
	fprintf(new_state.code, "struct %s_ctx *ctx = args;\n", name);

	indent(&new_state);
	fprintf(new_state.code, "struct fwd_arg *extern_args = fwd_stack_alloc(&stack);\n");

	size_t idx = 0;
	foreach_node(p, proc_params(proc)) {
		indent(&new_state);
		/* leave place for return value */
		fprintf(new_state.code, "extern_args[%zu] = (fwd_arg_t){%zu, "
				"{.%s = ctx->a%zu}};\n",
				idx + 1, fwd_typeid(p->t), fwd_typeparam(p->t), idx);

		char *type_str = lower_type_str(p->t);
		fprintf(new_state.ctx, "  %s a%zu;\n", type_str, idx);
		free(type_str);
		idx++;
	}

	indent(&new_state);
	fprintf(new_state.code, "%s((fwd_extern_args_t){.argc = %zu, .args = extern_args});\n",
			proc_id(proc), idx);

	indent(&new_state);
	fprintf(new_state.code, "fwd_stack_free(&stack, extern_args);\n");

	indent(&new_state);
	fprintf(new_state.code, "return stack;\n");

	fprintf(new_state.code, "}\n\n");
	fprintf(new_state.ctx, "};\n\n");

	fclose(new_state.code);
	assert(code_buf);
	add_defn(state, code_buf);

	fclose(new_state.ctx);
	assert(ctx_buf);
	add_type(state, ctx_buf);

	free(proto);
	free(name);
	return 0;
}

/* Append the member `<type> a<idx>;` for @param to the params struct being
 * written to @f. @state is unused. Always returns 0. */
static int lower_param_copy(struct state *state, struct ast *param, FILE *f, size_t idx)
{
	(void)state;

	char *c_type = lower_type_str(param->t);
	fprintf(f, "  %s a%zu;\n", c_type, idx);
	free(c_type);

	return 0;
}

/* Emit `struct <prefix>_params`, which has the same shape as the
 * caller-side argument structs from lower_call(): a leading fwd_start_t,
 * then one aN member per parameter. Also emit a statement that copies the
 * incoming argument buffer @args as a whole into the current context,
 * starting at ctx-><prefix>_start.
 *
 * Returns 0 on success, -1 if a parameter could not be lowered. */
static int lower_param_copies(struct state *state, struct ast *params)
{
	char *param_buf = NULL; size_t param_size = 0;
	FILE *f = open_memstream(&param_buf, &param_size);
	assert(f);
	fprintf(f, "struct %s_params {\n  fwd_start_t start;\n", state->prefix);

	size_t idx = 0;
	foreach_node(p, params) {
		if (lower_param_copy(state, p, f, idx)) {
			fclose(f);
			free(param_buf);
			return -1;
		}

		idx++;
	}

	fprintf(f, "};\n\n");
	fclose(f);
	assert(param_buf);

	add_type(state, param_buf);

	indent(state);
	fprintf(state->code, "*((struct %s_params *)&ctx->%s_start) ="
			" *((struct %s_params *)args);\n",
			state->prefix, state->prefix, state->prefix);
	return 0;
}

/* Lower procedure @proc into a C function named mangle2(proc), unless it
 * was lowered already. Procedures without a body go to lower_extern_proc().
 *
 * Emits a forward declaration (decls) and the function definition (defns).
 * It also emits `struct <name>_ctx` (types), which holds global_args, the
 * <name>_start marker and the parameter slots, plus any members that
 * closures in the body add.
 *
 * Returns 0 on success, -1 if lowering the body failed. */
static int lower_proc(struct state *state, struct ast *proc)
{
	if (proc_lowered(state, proc))
		return 0;

	if (!proc_body(proc))
		return lower_extern_proc(state, proc);

	/* register before lowering the body so recursive calls to @proc see
	 * it as already lowered instead of recursing forever */
	add_proc(state, proc);
	char *name = mangle2(proc);
	char *proto = buildstr("static fwd_stack_t %s(fwd_stack_t stack, fwd_args_t args)", name);
	char *decl = buildstr("%s;\n", proto);
	add_decl(state, decl);

	char *code_buf = NULL, *ctx_buf = NULL;
	size_t code_size = 0, ctx_size = 0;

	/* fresh streams and a fresh name counter: everything generated for
	 * this procedure is prefixed with its own mangled name */
	struct state new_state = create_state(state);
	new_state.indent = 0;
	new_state.prefix = name;
	new_state.current = proc;
	new_state.uniq = 0;
	new_state.code = open_memstream(&code_buf, &code_size);
	new_state.ctx = open_memstream(&ctx_buf, &ctx_size);
	assert(new_state.code);
	assert(new_state.ctx);

	char *start_of = buildstr("%s_start", name);

	fprintf(new_state.code, "%s\n{\n", proto);
	fprintf(new_state.ctx, "struct %s_ctx {\n", name);
	fprintf(new_state.ctx, "  fwd_args_t global_args;\n  fwd_start_t %s;\n", start_of);

	increase_indent(&new_state);
	indent(&new_state);
	fprintf(new_state.code, "static_assert(FWD_FRAME_SIZE >= sizeof(struct %s_ctx), \"context exceeds frame size\");\n",
			name);

	/* NOTE(review): AST_REQ_FRAME presumably marks procedures whose
	 * context must outlive this C call (e.g. closures hold &ctx->...),
	 * so it lives on the fwd stack; otherwise a local suffices. Confirm
	 * where the flag is computed. */
	if (ast_flags(proc, AST_REQ_FRAME)) {
		indent(&new_state);
		fprintf(new_state.code, "struct %s_ctx *ctx "
				"= fwd_stack_alloc(&stack);\n",
				name);
	}
	else {
		indent(&new_state);
		fprintf(new_state.code, "struct %s_ctx ctx_buf;\n", name);
		indent(&new_state);
		fprintf(new_state.code, "struct %s_ctx *ctx = &ctx_buf;\n", name);
	}

	/* allocates parameter slots */
	int ret = lower_params(&new_state, proc_params(proc));
	assert(ret == 0);

	/* actually copies values into parameter slots; needed because calls
	 * in the body write their own arguments into the same buffer via
	 * ctx->global_args (see lower_call()) */
	ret = lower_param_copies(&new_state, proc_params(proc));
	assert(ret == 0);

	indent(&new_state);
	fprintf(new_state.code, "ctx->global_args = args;\n");

	struct ast *block = proc_body(proc);
	ret = lower_stmt_list(&new_state, block_body(block), true);

	fprintf(new_state.code, "}\n\n");
	fprintf(new_state.ctx, "};\n\n");

	fclose(new_state.code);
	fclose(new_state.ctx);
	assert(code_buf);
	assert(ctx_buf);

	/* ownership of both buffers passes to the collections */
	add_defn(state, code_buf);
	add_type(state, ctx_buf);

	free(start_of);
	free(proto);
	free(name);
	return ret;
}

int lower(struct scope *root)
{
	struct ast *main = file_scope_find_symbol(root, "main");
	if (!main) {
		error("no main");
		return -1;
	}

	if (main->k != AST_PROC_DEF) {
		error("main is not a procedure");
		return -1;
	}

	struct string_vec defns = string_vec_create(0);
	struct string_vec decls = string_vec_create(0);
	struct string_vec types = string_vec_create(0);
	struct proc_set procs = proc_set_create(0);

	struct state state = {
		.indent = 0,
		.ctx = NULL,
		.code = NULL,
		.defns = &defns,
		.decls = &decls,
		.types = &types,
		.procs = &procs
	};

	if (lower_proc(&state, main))
		return -1;

	/* placeholder, should really be calculated to be some maximum frame
	 * size or something */
	printf("#define FWD_FRAME_SIZE %d\n", 1024);
	printf("#include <fwd.h>\n\n");

	foreach(string_vec, s, state.types) {
		puts(*s);
		free(*s);
	}

	foreach(string_vec, s, state.decls) {
		puts(*s);
		free(*s);
	}

	foreach(string_vec, s, state.defns) {
		puts(*s);
		free(*s);
	}

	char *name = mangle2(main);

	printf("int main()\n");
	printf("{\n");
	printf("  fwd_stack_t stack = create_fwd_stack();\n");
	printf("  void *args = fwd_stack_alloc(&stack);\n");
	printf("  stack = %s(stack, args);\n", name);
	printf("  fwd_stack_free(&stack, args);\n");
	printf("  destroy_fwd_stack(&stack);\n\n");
	printf("}\n");

	/* modules require a register function of some kind, implement a stub */
	printf("int fwd_register(struct fwd_state *state, const char *name, fwd_extern_t func, fwd_type_t rtype, ...) {return 0;}\n");

	string_vec_destroy(&defns);
	string_vec_destroy(&decls);
	string_vec_destroy(&types);
	proc_set_destroy(&procs);
	free(name);
	return 0;
}