From bf804fa1e18c717cec3944f5edea858a2f3a015d Mon Sep 17 00:00:00 2001
From: Kimplul <kimi.h.kuparinen@gmail.com>
Date: Sat, 28 Dec 2024 16:40:41 +0200
Subject: enough type checking for all examples to pass

---
 src/analyze.c | 141 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 src/ast.c     |  16 +++++++
 src/debug.c   |   2 +
 3 files changed, 158 insertions(+), 1 deletion(-)

(limited to 'src')

diff --git a/src/analyze.c b/src/analyze.c
index 900e771..10e1f0b 100644
--- a/src/analyze.c
+++ b/src/analyze.c
@@ -2,6 +2,7 @@
 /* Copyright 2024 Kim Kuparinen < kimi.h.kuparinen@gmail.com > */
 
 #include <fwd/analyze.h>
+#include <string.h>
 #include <assert.h>
 
 struct state {
@@ -47,7 +48,67 @@ static int analyze_proc(struct state *state, struct scope *scope, struct ast *no
 
 static int analyze_unop(struct state *state, struct scope *scope, struct ast *node)
 {
-	assert(false);
+	/** @todo check expr is some primitive type */
+	struct ast *expr = unop_expr(node);
+	if (analyze(state, scope, expr))
+		return -1;
+
+	node->t = expr->t;
+	return 0;
+}
+
+static int analyze_binop(struct state *state, struct scope *scope, struct ast *node)
+{
+	struct ast *lhs = binop_left(node);
+	struct ast *rhs = binop_right(node);
+
+	if (analyze(state, scope, lhs))
+		return -1;
+
+	if (analyze(state, scope, rhs))
+		return -1;
+
+
+	if (!types_match(lhs->t, rhs->t)) {
+		type_mismatch(scope, node, lhs->t, rhs->t);
+		return -1;
+	}
+
+	/** @todo check type is some primitive */
+	node->t = lhs->t;
+	return 0;
+}
+
+static int analyze_comparison(struct state *state, struct scope *scope, struct ast *node)
+{
+	struct ast *lhs = comparison_left(node);
+	struct ast *rhs = comparison_right(node);
+
+	if (analyze(state, scope, lhs))
+		return -1;
+
+	if (analyze(state, scope, rhs))
+		return -1;
+
+
+	if (!types_match(lhs->t, rhs->t)) {
+		type_mismatch(scope, node, lhs->t, rhs->t);
+		return -1;
+	}
+
+	/** @todo check type is some primitive */
+	char *tf = strdup("bool");
+	if (!tf) {
+		internal_error("failed allocating comparison bool str");
+		return -1;
+	}
+
+	node->t = tgen_id(tf, node->loc);
+	if (!node->t) {
+		internal_error("failed allocating comparison bool type");
+		return -1;
+	}
+
 	return 0;
 }
 
@@ -91,6 +152,7 @@ static int analyze_let(struct state *state, struct scope *scope, struct ast *nod
 		return -1;
 	}
 
+	/** @todo check move semantics, maybe in another pass? */
 	node->t = l;
 	return 0;
 }
@@ -154,6 +216,7 @@ static int analyze_call(struct state *state, struct scope *scope, struct ast *no
 		return -1;
 	}
 
+	/** @todo check move semantics? */
 	return 0;
 }
 
@@ -213,6 +276,69 @@ static int analyze_closure(struct state *state, struct scope *scope, struct ast
 	return 0;
 }
 
+static int analyze_int(struct state *state, struct scope *scope, struct ast *node)
+{
+	/** @todo do this properly, very hacky, bad bad bad */
+	char *i = strdup("int");
+	if (!i) {
+		internal_error("failed allocating constant int type string");
+		return -1;
+	}
+
+	node->t = tgen_id(i, node->loc);
+	if (!node->t) {
+		internal_error("failed allocating constant int type");
+		return -1;
+	}
+
+	return 0;
+}
+
+static int analyze_str(struct state *state, struct scope *scope, struct ast *node)
+{
+	/** @todo do this properly, very hacky, bad bad bad */
+	char *i = strdup("char");
+	if (!i) {
+		internal_error("failed allocating constant char type string");
+		return -1;
+	}
+
+	struct type *ch = tgen_id(i, node->loc);
+	if (!ch) {
+		internal_error("failed allocating constant char type");
+		return -1;
+	}
+
+	struct type *str = tgen_ptr(ch, node->loc);
+	if (!str) {
+		internal_error("failed allocating constant str type");
+		return -1;
+	}
+
+	node->t = str;
+	return 0;
+}
+
+static int analyze_if(struct state *state, struct scope *scope, struct ast *node)
+{
+	if (analyze(state, scope, if_cond(node)))
+		return -1;
+
+	if (analyze(state, scope, if_body(node)))
+		return -1;
+
+	if (analyze(state, scope, if_else(node)))
+		return -1;
+
+	node->t = tgen_void(node->loc);
+	if (!node->t) {
+		internal_error("failed allocating 'if' void type");
+		return -1;
+	}
+
+	return 0;
+}
+
 static int analyze(struct state *state, struct scope *scope, struct ast *node)
 {
 	if (!node)
@@ -234,6 +360,16 @@ static int analyze(struct state *state, struct scope *scope, struct ast *node)
 		node->scope = scope;
 
 	int ret = 0;
+	if (is_binop(node)) {
+		ret = analyze_binop(state, scope, node);
+		goto out;
+	}
+
+	if (is_comparison(node)) {
+		ret = analyze_comparison(state, scope, node);
+		goto out;
+	}
+
 	if (is_unop(node)) {
 		ret = analyze_unop(state, scope, node);
 		goto out;
@@ -247,7 +383,10 @@ static int analyze(struct state *state, struct scope *scope, struct ast *node)
 	case AST_INIT:		ret = analyze_init	(state, scope, node); break;
 	case AST_CALL:		ret = analyze_call	(state, scope, node); break;
 	case AST_ID:		ret = analyze_id	(state, scope, node); break;
+	case AST_IF:		ret = analyze_if	(state, scope, node); break;
 	case AST_CLOSURE:	ret = analyze_closure	(state, scope, node); break;
+	case AST_CONST_INT:	ret = analyze_int	(state, scope, node); break;
+	case AST_CONST_STR:	ret = analyze_str	(state, scope, node); break;
 	default:
 		   internal_error("missing ast analysis");
 		   return -1;
diff --git a/src/ast.c b/src/ast.c
index 17a4461..119f943 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -549,6 +549,19 @@ void fix_closures(struct ast *root)
 	}
 }
 
+static bool special_auto_very_bad(struct type *a, struct type *b)
+{
+	/** @todo massive hack, accept 'auto' as a placeholder and match it
+	 * against anything, will need to be fixed eventually */
+	if (a->k == TYPE_ID && strcmp(a->id, "auto") == 0)
+		return true;
+
+	if (b->k == TYPE_ID && strcmp(b->id, "auto") == 0)
+		return true;
+
+	return false;
+}
+
 bool types_match(struct type *a, struct type *b)
 {
 	if (!a && !b)
@@ -560,6 +573,9 @@ bool types_match(struct type *a, struct type *b)
 	if (!a && b)
 		return false;
 
+	if (special_auto_very_bad(a, b))
+		return true;
+
 	if (a->k != b->k)
 		return false;
 
diff --git a/src/debug.c b/src/debug.c
index f823442..ecb03e3 100644
--- a/src/debug.c
+++ b/src/debug.c
@@ -168,10 +168,12 @@ static void _type_str(FILE *f, struct type *type)
 
 	case TYPE_PTR:
 		fprintf(f, "*");
+		_type_list_str(f, tptr_base(type));
 		break;
 
 	case TYPE_REF:
 		fprintf(f, "&");
+		_type_list_str(f, tref_base(type));
 		break;
 
 	case TYPE_ID:
-- 
cgit v1.2.3