/* SPDX-License-Identifier: copyleft-next-0.3.1 */
/* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */

#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <assert.h>

#include <fwd/lower.h>
#include <fwd/scope.h>
#include <fwd/vec.h>

/* Lowering context carried through the emitters of a single procedure. */
struct state {
	/* Current nesting depth; indent() writes one space per level. */
	long indent;
};

/* Step one nesting level deeper. */
static void increase_indent(struct state *state)
{
	state->indent += 1;
}

/* Step one nesting level back out. */
static void decrease_indent(struct state *state)
{
	state->indent -= 1;
}

/* Write the current indentation to stdout: one space per level, nothing
 * when the level is zero or negative. */
static void indent(struct state *state)
{
	for (long left = state->indent; left > 0; --left)
		putchar(' ');
}

/* Forward declarations. Expression lowering is recursive: lower_expr()
 * reaches lower_closure() -> lower_block() -> statements -> lower_expr()
 * again. lower_type() and lower_types() also call each other. */
static int lower_expr(struct state *state, struct ast *expr);
static int lower_block(struct state *state, struct ast *block);
static int lower_closure(struct state *state, struct ast *closure);

static int lower_types(struct type *types);

/*
 * Lower a binary operation. It prints "(" and the left operand, is meant to
 * print the operator and the right operand, and closes with ")".
 *
 * Returns 0 on success, -1 on failure.
 *
 * NOTE(review): the operator switch has no cases yet, so every binop ends in
 * internal_error() and returns -1 after "(" and the left operand have
 * already been written. The code after the switch is currently unreachable.
 */
static int lower_binop(struct state *state, struct ast *binop)
{
	printf("(");
	if (lower_expr(state, binop_left(binop)))
		return -1;

	/* TODO(review): add one case per binop kind printing its C++ operator. */
	switch (binop->k) {
	default:
		internal_error("missing binop lowering");
		return -1;
	}

	if (lower_expr(state, binop_right(binop)))
		return -1;

	printf(")");
	return 0;
}

/*
 * Lower a comparison. It prints "(" and the left operand, is meant to print
 * the comparison operator and the right operand, and closes with ")".
 *
 * Returns 0 on success, -1 on failure.
 *
 * NOTE(review): the operator switch has no cases yet, so every comparison
 * ends in internal_error() and returns -1 after "(" and the left operand
 * have already been written. The code after the switch is currently
 * unreachable.
 */
static int lower_comparison(struct state *state, struct ast *comp)
{
	printf("(");
	if (lower_expr(state, comparison_left(comp)))
		return -1;

	/* TODO(review): add one case per comparison kind printing its C++ operator. */
	switch (comp->k) {
	default:
		internal_error("missing comparison lowering");
		return -1;
	}

	if (lower_expr(state, comparison_right(comp)))
		return -1;

	printf(")");
	return 0;
}

/*
 * Emit a linked list of expressions separated by ", ".
 * A NULL list emits nothing.
 *
 * Returns 0 on success, -1 as soon as one element fails to lower.
 */
static int lower_exprs(struct state *state, struct ast *exprs)
{
	if (exprs == NULL)
		return 0;

	/* Empty before the first element, ", " before every later one. */
	const char *sep = "";
	foreach_node(expr, exprs) {
		printf("%s", sep);
		if (lower_expr(state, expr))
			return -1;

		sep = ", ";
	}

	return 0;
}

/*
 * Emit a type as C++ text. A TYPE_ID is printed as its identifier. A
 * TYPE_CONSTRUCT is printed as "id<arg, ...>", with arguments lowered
 * through lower_types().
 *
 * Returns 0 on success, -1 if an argument fails to lower or the type kind
 * is not handled.
 */
static int lower_type(struct type *type)
{
	if (type->k == TYPE_ID) {
		printf("%s", tid_str(type));
		return 0;
	}

	/* Only identifiers and constructs are lowered so far. The assert flags
	 * other kinds in debug builds. The explicit check keeps NDEBUG builds,
	 * where assert() is compiled out, from misreading any other kind as a
	 * construct. */
	assert(type->k == TYPE_CONSTRUCT);
	if (type->k != TYPE_CONSTRUCT)
		return -1;

	printf("%s<", tconstruct_id(type));

	if (lower_types(tconstruct_args(type)))
		return -1;

	printf(">");
	return 0;
}

static int lower_types(struct type *types)
{
	if (!types)
		return 0;

	if (lower_type(types))
		return -1;

	foreach_type(type, types->n) {
		printf(", ");
		if (lower_type(type))
			return -1;
	}

	return 0;
}

/*
 * Emit an initializer in the form "id<type, ...>{expr, ...}". The angle
 * bracket section is only written when the initializer has type arguments.
 *
 * Returns 0 on success, -1 if a type argument or body expression fails.
 */
static int lower_init(struct state *state, struct ast *init)
{
	printf("%s", init_id(init));

	struct type *targs = init_args(init);
	if (targs) {
		printf("<");
		if (lower_types(targs))
			return -1;
		printf(">");
	}

	printf("{");
	if (lower_exprs(state, init_body(init)))
		return -1;
	printf("}");

	return 0;
}

/*
 * Emit a C++ character literal for the byte c.
 *
 * A quote, backslash or control byte cannot appear verbatim inside a
 * character literal, so such bytes are written as escape sequences. Bytes
 * outside printable 7-bit ASCII use octal escapes.
 */
static void lower_char_literal(unsigned char c)
{
	switch (c) {
	case '\'': printf("'\\''"); return;
	case '\\': printf("'\\\\'"); return;
	case '\n': printf("'\\n'"); return;
	case '\t': printf("'\\t'"); return;
	case '\r': printf("'\\r'"); return;
	default: break;
	}

	if (c < 0x20 || c >= 0x7f)
		printf("'\\%03o'", (unsigned)c);
	else
		printf("'%c'", c);
}

/*
 * Emit a single expression as C++ text.
 *
 * Returns 0 on success, -1 if the expression (or a subexpression) has no
 * lowering yet.
 */
static int lower_expr(struct state *state, struct ast *expr)
{
	if (is_binop(expr))
		return lower_binop(state, expr);

	if (is_comparison(expr))
		return lower_comparison(state, expr);

	switch (expr->k) {
	case AST_ID:            printf("%s", id_str(expr)); break;
	case AST_CONST_INT:     printf("%lld", int_val(expr)); break;
	/* Plain %f keeps only six decimals (1e-7 became 0.000000). Seventeen
	 * significant digits round-trip a double (assuming float_val() yields
	 * one). The '#' flag always keeps the decimal point, so the literal
	 * never degrades into a C++ int. */
	case AST_CONST_FLOAT:   printf("%#.17g", float_val(expr)); break;
	case AST_CONST_BOOL:    printf("%s", bool_val(expr) ? "true" : "false");
		break;
	/* NOTE(review): the string is emitted verbatim between quotes. That is
	 * only valid C++ if str_val() still holds the source's escape
	 * sequences; confirm against the lexer. */
	case AST_CONST_STR:     printf("\"%s\"", str_val(expr)); break;
	case AST_CONST_CHAR:    lower_char_literal((unsigned char)char_val(expr));
		break;
	case AST_INIT:          return lower_init(state, expr);
	case AST_CLOSURE:       return lower_closure(state, expr);
	default:
		internal_error("missing expr lowering");
		return -1;
	}

	return 0;
}

/*
 * Emit one call argument. A bare identifier is wrapped as "move(id)".
 * Anything else is lowered as an ordinary expression.
 *
 * Returns 0 on success, -1 if the expression fails to lower.
 */
static int lower_move(struct state *state, struct ast *move)
{
	if (move->k != AST_ID)
		return lower_expr(state, move);

	/** @todo once I start messing about with references, moves
	 * should only be outputted for parameters that take ownership
	 */
	printf("move(%s)", id_str(move));
	return 0;
}

/*
 * Emit a linked list of call arguments separated by ", ", each passed
 * through lower_move(). A NULL list emits nothing.
 *
 * Returns 0 on success, -1 as soon as one argument fails.
 */
static int lower_moves(struct state *state, struct ast *moves)
{
	if (moves == NULL)
		return 0;

	/* Empty before the first argument, ", " before every later one. */
	const char *sep = "";
	foreach_node(move, moves) {
		printf("%s", sep);
		if (lower_move(state, move))
			return -1;

		sep = ", ";
	}

	return 0;
}

/*
 * Emit a call statement "callee(arg, ...);" followed by a newline.
 *
 * Returns 0 on success, -1 if the callee or an argument fails to lower.
 */
static int lower_call(struct state *state, struct ast *call)
{
	struct ast *callee = call_expr(call);
	struct ast *args = call_args(call);

	if (lower_expr(state, callee))
		return -1;

	printf("(");
	if (lower_moves(state, args))
		return -1;

	printf(");\n");
	return 0;
}

/*
 * Emit a let binding as "auto id = expr;" followed by a newline.
 *
 * Returns 0 on success, -1 if the bound expression fails to lower.
 */
static int lower_let(struct state *state, struct ast *let)
{
	printf("auto %s = ", let_id(let));

	int err = lower_expr(state, let_expr(let));
	if (err)
		return -1;

	printf(";\n");
	return 0;
}

/*
 * Dispatch one statement to its emitter. Only let bindings and calls are
 * supported; any other kind is reported through internal_error().
 *
 * Returns the emitter's result, or -1 for unsupported kinds.
 */
static int lower_statement(struct state *state, struct ast *stmt)
{
	if (stmt->k == AST_LET)
		return lower_let(state, stmt);

	if (stmt->k == AST_CALL)
		return lower_call(state, stmt);

	internal_error("missing statement lowering");
	return -1;
}

/*
 * Emit a braced block. It writes "{" and a newline, then each statement
 * prefixed by indentation one level deeper, then "}" at the current level.
 * No newline follows the closing brace; the caller decides what comes next.
 *
 * Returns 0 on success, -1 if any statement fails to lower. The indent
 * level in @state is restored on both paths.
 */
static int lower_block(struct state *state, struct ast *block)
{
	printf("{\n");
	increase_indent(state);

	foreach_node(stmt, block_body(block)) {
		indent(state);

		if (lower_statement(state, stmt)) {
			/* Keep increase/decrease balanced even when bailing out. */
			decrease_indent(state);
			return -1;
		}
	}

	decrease_indent(state);

	indent(state);
	printf("}");
	return 0;
}

/*
 * Emit a parameter list as "auto a, auto b, ...". Every parameter type is
 * left to C++ deduction. A NULL list emits nothing.
 *
 * Currently always returns 0; callers still check it.
 */
static int lower_params(struct ast *params)
{
	if (params == NULL)
		return 0;

	/* Empty before the first parameter, ", " before every later one. */
	const char *sep = "";
	foreach_node(param, params) {
		printf("%sauto %s", sep, var_id(param));
		sep = ", ";
	}

	return 0;
}

/*
 * Emit a closure as a C++ lambda that captures by reference: "[&](params)"
 * followed by the body block.
 *
 * Returns 0 on success, -1 if the parameters or body fail to lower.
 */
static int lower_closure(struct state *state, struct ast *closure)
{
	printf("[&](");
	if (lower_params(closure_bindings(closure)))
		return -1;
	printf(")");

	return lower_block(state, closure_body(closure)) ? -1 : 0;
}

/*
 * Emit "ret name(params)" for a procedure, with no terminator. main() is
 * given an int return type, since C++ requires it; every other procedure
 * returns void.
 *
 * Shared by lower_proto() and lower_proc() so the forward declaration and
 * the definition cannot drift apart.
 *
 * Returns 0 on success, -1 if the parameters fail to lower.
 */
static int lower_signature(struct ast *proc)
{
	const char *id = proc_id(proc);

	printf("%s %s(", strcmp("main", id) == 0 ? "int" : "void", id);

	if (lower_params(proc_params(proc)))
		return -1;

	printf(")");
	return 0;
}

/*
 * Emit a forward declaration "ret name(params);" followed by a blank line.
 *
 * Returns 0 on success, -1 on failure.
 */
static int lower_proto(struct ast *proc)
{
	if (lower_signature(proc))
		return -1;

	printf(";\n\n");
	return 0;
}

/*
 * Emit a full procedure definition: the signature, a newline, the body
 * block, and a blank line. Each procedure is lowered with a fresh
 * indentation state.
 *
 * Returns 0 on success, -1 on failure.
 */
static int lower_proc(struct ast *proc)
{
	if (lower_signature(proc))
		return -1;

	printf("\n");

	struct state state = {0};
	if (lower_block(&state, proc_body(proc)))
		return -1;

	printf("\n\n");
	return 0;
}

int lower(struct scope *root)
{
	printf("#include <fwdlib.hpp>\n");

	foreach_visible(visible, root->symbols) {
		struct ast *proc = visible->node;
		assert(proc->k == AST_PROC_DEF);
		if (lower_proto(proc))
			return -1;
	}

	foreach_visible(visible, root->symbols) {
		struct ast *proc = visible->node;
		if (lower_proc(proc))
			return -1;
	}

	return 0;
}