diff options
| author | Kimplul <kimi.h.kuparinen@gmail.com> | 2024-12-20 14:52:34 +0200 | 
|---|---|---|
| committer | Kimplul <kimi.h.kuparinen@gmail.com> | 2024-12-20 14:52:34 +0200 | 
| commit | 98c3d8fbc924c62e2be571ed71b22053b9e8baa3 (patch) | |
| tree | 695c7877b7802cec60037ecf98052479c9bb06b0 | |
| parent | 0f5ce98342a7742c4e3af0dd33b5b642419d5286 (diff) | |
| download | fwd-98c3d8fbc924c62e2be571ed71b22053b9e8baa3.tar.gz fwd-98c3d8fbc924c62e2be571ed71b22053b9e8baa3.zip | |
add enough type checking to compile uniq.fwd
| -rw-r--r-- | examples/uniq.fwd | 3 | ||||
| -rw-r--r-- | include/fwd/ast.h | 56 | ||||
| -rw-r--r-- | include/fwd/debug.h | 16 | ||||
| -rw-r--r-- | src/analyze.c | 298 | ||||
| -rw-r--r-- | src/ast.c | 106 | ||||
| -rw-r--r-- | src/debug.c | 91 | ||||
| -rw-r--r-- | src/scope.c | 8 | 
7 files changed, 557 insertions, 21 deletions
| diff --git a/examples/uniq.fwd b/examples/uniq.fwd index 5945d65..9eb0092 100644 --- a/examples/uniq.fwd +++ b/examples/uniq.fwd @@ -5,7 +5,8 @@ fwd_getline(  	(optional![string]) next);  fwd_some(optional![string] o, -	(string) next); +	(string) something, +	() nothing);  fwd_insert(unordered_set![string] set, string line,  	(unordered_set![string]) next); diff --git a/include/fwd/ast.h b/include/fwd/ast.h index 34fcd64..62766fb 100644 --- a/include/fwd/ast.h +++ b/include/fwd/ast.h @@ -5,6 +5,7 @@  #define AST_H  #include <stddef.h> +#include <assert.h>  #include <stdbool.h>  /** @@ -30,7 +31,7 @@ struct src_loc {  struct ast;  enum type_kind { -	TYPE_ID = 1, TYPE_CONSTRUCT, TYPE_REF, TYPE_PTR, TYPE_CALLABLE +	TYPE_ID = 1, TYPE_CONSTRUCT, TYPE_REF, TYPE_PTR, TYPE_CALLABLE, TYPE_VOID  };  struct type { @@ -115,6 +116,11 @@ enum ast_kind {  	AST_CONST_STR,  }; +enum ast_flag { +	AST_FLAG_ANALYZED	= (1 << 0), +	AST_FLAG_PREANALYZIS	= (1 << 1), +}; +  struct ast {  	enum ast_kind k;  	double d; @@ -129,8 +135,21 @@ struct ast {  	struct ast *n;  	struct src_loc loc;  	struct scope *scope; + +	struct type *t; +	enum ast_flag flags;  }; +static inline bool ast_flags(struct ast *node, enum ast_flag flags) +{ +	return node->flags & flags; +} + +static inline void ast_set_flags(struct ast *node, enum ast_flag flags) +{ +	node->flags |= flags; +} +  struct ast *gen_ast(enum ast_kind kind,                      struct ast *a0,                      struct ast *a1, @@ -241,6 +260,9 @@ static inline bool is_const(struct ast *x)  #define return_id(x, kind) *({assert((x)->k == kind); &(x)->id;})  #define return_t0(x, kind) *({assert((x)->k == kind); &(x)->t0;}) +#define tgen_void(loc) \ +	tgen_type(TYPE_VOID, NULL, NULL, loc) +  #define tid_str(x) return_id(x, TYPE_ID)  #define tgen_id(id, loc) \  	tgen_type(TYPE_ID, NULL, id, loc) @@ -355,10 +377,18 @@ static inline bool is_const(struct ast *x)  struct ast *clone_ast(struct ast *n);  struct ast *clone_ast_list(struct ast *l); +struct type *clone_type(struct type *type); +struct type *clone_type_list(struct type *types); +  void ast_dump_list(int depth, struct ast *root);  void ast_dump(int depth, struct ast *node);  void ast_append(struct ast **list, struct ast *elem); +void type_append(struct type **list, struct type *elem); + +bool types_match(struct type *a, struct type *b); +bool type_lists_match(struct type *al, struct type *bl); +  struct ast *ast_prepend(struct ast *list, struct ast *elem);  typedef int (*ast_callback_t)(struct ast *, void *); @@ -375,6 +405,7 @@ int ast_visit_list(ast_callback_t before, ast_callback_t after,   * @return Number of elements in \p list.   */  size_t ast_list_len(struct ast *list); +size_t type_list_len(struct type *list);  /**   * Get last nose in ASt list. @@ -410,4 +441,27 @@ void fix_closures(struct ast *root);  #define foreach_type(iter, types) \  	for (struct type *iter = types; iter; iter = iter->n) +static inline bool is_callable(struct type *t) +{ +	if (t->k == TYPE_PTR) +		return (tptr_base(t))->k == TYPE_CALLABLE; + +	if (t->k == TYPE_REF) +		return (tref_base(t))->k == TYPE_CALLABLE; + +	return t->k == TYPE_CALLABLE; +} + +static inline struct type *callable_types(struct type *t) +{ +	assert(is_callable(t)); +	if (t->k == TYPE_REF) +		return callable_types(tptr_base(t)); + +	if (t->k == TYPE_PTR) +		return callable_types(tref_base(t)); + +	return tcallable_args(t); +} +  #endif /* AST_H */ diff --git a/include/fwd/debug.h b/include/fwd/debug.h index 85ed9db..cc20c83 100644 --- a/include/fwd/debug.h +++ b/include/fwd/debug.h @@ -68,7 +68,7 @@ struct file_ctx {   * @param node AST node to print message with.   * @param fmt Format string. Follows standard printf() formatting.   */ -void semantic_info(struct file_ctx ctx, struct ast *node, const char *fmt, ...); +void semantic_info(struct scope *scope, struct ast *node, const char *fmt, ...);  /**   * Print warning that relates to a specific AST node. @@ -83,8 +83,7 @@ void semantic_info(struct file_ctx ctx, struct ast *node, const char *fmt, ...);   * @param node AST node to print message with.   * @param fmt Format string. Follows standard printf() formatting.   */ -void semantic_warn(struct file_ctx ctx, struct ast *node, const char *fmt, -                   ...); +void semantic_warn(struct scope *scope, struct ast *node, const char *fmt, ...);  /**   * Print warning that relates to a specific AST node. @@ -94,9 +93,9 @@ void semantic_warn(struct file_ctx ctx, struct ast *node, const char *fmt,   * @param node AST node to print message with.   * @param fmt Format string. Follows standard printf() formatting.   */ -void semantic_error(struct file_ctx ctx, struct ast *node, const char *fmt, -                    ...); -void loc_error(struct file_ctx ctx, struct src_loc loc, const char *fmt, ...); +void semantic_error(struct scope *scope, struct ast *node, const char *fmt, ...); + +void loc_error(struct scope *scope, struct src_loc loc, const char *fmt, ...);  /**   * Print internal error. @@ -135,4 +134,9 @@ struct src_issue {   */  void src_issue(struct src_issue issue, const char *err_msg, ...); +void type_mismatch(struct scope *scope, struct ast *node, +		struct type *l, struct type *r); + +const char *type_str(struct type *t); +  #endif /* FWD_DEBUG_H */ diff --git a/src/analyze.c b/src/analyze.c index 6efd60f..900e771 100644 --- a/src/analyze.c +++ b/src/analyze.c @@ -4,11 +4,307 @@  #include <fwd/analyze.h>  #include <assert.h> +struct state { +}; + +static int analyze(struct state *state, struct scope *scope, struct ast *node); +static int analyze_list(struct state *state, struct scope *scope, struct ast *nodes); + +static int analyze_type(struct state *state, struct scope *scope, struct type *type); +static int analyze_type_list(struct state *state, struct scope *scope, struct type *types); + +static int analyze_proc(struct state *state, struct scope *scope, struct ast *node) +{ +	(void)state; +	(void)scope; + +	struct scope *proc_scope = create_scope(); +	if (!proc_scope) { +		internal_error("failed allocating proc scope"); +		return -1; +	} + +	scope_add_scope(scope, proc_scope); + +	struct state proc_state = {}; +	if (analyze_list(&proc_state, proc_scope, proc_params(node))) +		return -1; + +	struct type *callable = tgen_callable(NULL, node->loc); +	if (!callable) { +		internal_error("failed allocating proc type"); +		return -1; +	} + +	foreach_node(param, proc_params(node)) { +		type_append(&callable->t0, clone_type(param->t)); +	} + +	node->t = callable; + +	return analyze(&proc_state, proc_scope, proc_body(node)); +} + +static int analyze_unop(struct state *state, struct scope *scope, struct ast *node) +{ +	assert(false); +	return 0; +} + +static int analyze_block(struct state *state, struct scope *scope, struct ast *node) +{ +	struct scope *block_scope = create_scope(); +	if (!block_scope) { +		internal_error("failed to allocate block scope"); +		return -1; +	} + +	scope_add_scope(scope, block_scope); +	if (analyze_list(state, block_scope, block_body(node))) +		return -1; + +	node->t = ast_last(block_body(node))->t; +	return 0; +} + +static int analyze_var(struct state *state, struct scope *scope, struct ast *node) +{ +	if (analyze_type(state, scope, var_type(node))) +		return -1; + +	node->t = var_type(node); +	return scope_add_var(scope, node); +} + +static int analyze_let(struct state *state, struct scope *scope, struct ast *node) +{ +	if (analyze(state, scope, let_var(node))) +		return -1; + +	if (analyze(state, scope, let_expr(node))) +		return -1; + +	struct type *l = (let_var(node))->t; +	struct type *r = (let_expr(node))->t; +	if (!types_match(l, r)) { +		type_mismatch(scope, node, l, r); +		return -1; +	} + +	node->t = l; +	return 0; +} + +static int analyze_init(struct state *state, struct scope *scope, struct ast *node) +{ +	if (analyze_type_list(state, scope, init_args(node))) +		return -1; + +	if (analyze(state, scope, init_body(node))) +		return -1; + +	/** @todo check that all parameters match, though that requires that the +	 * type is defined in fwd and not in C++, so this might have to wait a +	 * bit */ +	node->t = tgen_construct(init_id(node), init_args(node), node->loc); +	if (!node->t) { +		internal_error("failed allocating type for construct"); +		return -1; +	} + +	return 0; +} + +static int analyze_call(struct state *state, struct scope *scope, struct ast *node) +{ +	struct ast *expr = call_expr(node); +	if (analyze(state, scope, expr)) +		return -1; + +	if (!is_callable(expr->t)) { +		semantic_error(scope, node, "expected callable expression"); +		return -1; +	} + +	struct type *types = callable_types(expr->t); +	struct ast *args = call_args(node); +	if (analyze_list(state, scope, args)) +		return -1; + +	size_t expected = type_list_len(types); +	size_t got = ast_list_len(args); +	if (expected != got) { +		semantic_error(scope, node, "expected %d params, got %d", expected, got); +		return -1; +	} + +	struct type *type = types; +	foreach_node(arg, args) { +		if (!types_match(type, arg->t)) { +			type_mismatch(scope, node, type, arg->t); +			return -1; +		} + +		type = type->n; +	} + +	node->t = tgen_void(node->loc); +	if (!node->t) { +		internal_error("failed allocating void type for call"); +		return -1; +	} + +	return 0; +} + +static int analyze_id(struct state *state, struct scope *scope, struct ast *node) +{ +	struct ast *found = file_scope_find_symbol(scope, id_str(node)); +	if (!found) { +		semantic_error(scope, node, "no symbol named \"%s\"", id_str(node)); +		return -1; +	} + +	/* kind of hacky, functions are given their type before they've been +	 * analyzed fully, this enables recursion */ +	if (!found->t && !ast_flags(found, AST_FLAG_ANALYZED)) { +		assert(found->k == AST_PROC_DEF); +		/* a proc def will use its own state and scope, but pass these +		 * on in case the analysis wants to print errors or something +		 */ +		if (analyze(state, scope, found)) +			return -1; +	} + +	node->t = found->t; +	return 0; +} + +static int analyze_closure(struct state *state, struct scope *scope, struct ast *node) +{ +	struct scope *closure_scope = create_scope(); +	if (!closure_scope) { +		internal_error("failed allocating closure scope"); +		return -1; +	} + +	scope_add_scope(scope, closure_scope); + +	if (analyze_list(state, closure_scope, closure_bindings(node))) +		return -1; + +	if (analyze(state, closure_scope, closure_body(node))) +		return -1; + +	struct type *callable = tgen_callable(NULL, node->loc); +	if (!callable) { +		internal_error("failed allocating closure type"); +		return -1; +	} + +	/** @todo use analysis to figure out if this closure can be called +	 * multiple times or just once */ + +	foreach_node(binding, closure_bindings(node)) { +		type_append(&callable->t0, clone_type(binding->t)); +	} + +	node->t = callable; +	return 0; +} + +static int analyze(struct state *state, struct scope *scope, struct ast *node) +{ +	if (!node) +		return 0; + +	if (ast_flags(node, AST_FLAG_ANALYZED)) { +		assert(node->t); +		return 0; +	} + +	if (ast_flags(node, AST_FLAG_PREANALYZIS)) { +		semantic_error(scope, node, "semantic loop detected"); +		return -1; +	} + +	ast_set_flags(node, AST_FLAG_PREANALYZIS); + +	if (!node->scope) +		node->scope = scope; + +	int ret = 0; +	if (is_unop(node)) { +		ret = analyze_unop(state, scope, node); +		goto out; +	} + +	switch (node->k) { +	case AST_PROC_DEF:	ret = analyze_proc	(state, scope, node); break; +	case AST_VAR_DEF:	ret = analyze_var	(state, scope, node); break; +	case AST_BLOCK:		ret = analyze_block	(state, scope, node); break; +	case AST_LET:		ret = analyze_let	(state, scope, node); break; +	case AST_INIT:		ret = analyze_init	(state, scope, node); break; +	case AST_CALL:		ret = analyze_call	(state, scope, node); break; +	case AST_ID:		ret = analyze_id	(state, scope, node); break; +	case AST_CLOSURE:	ret = analyze_closure	(state, scope, node); break; +	default: +		   internal_error("missing ast analysis"); +		   return -1; +	} + +out: +	if (ret) +		return ret; + +	assert(node->t); +	assert(node->scope); +	ast_set_flags(node, AST_FLAG_ANALYZED); +	return 0; +} + +static int analyze_list(struct state *state, struct scope *scope, struct ast *nodes) +{ +	foreach_node(node, nodes) { +		if (analyze(state, scope, node)) +			return -1; +	} + +	return 0; +} + +static int analyze_type(struct state *state, struct scope *scope, struct type *type) +{ +	/* for now, let's just say all types are fine as they are, specified by +	 * the user. */ +	(void)state; +	(void)scope; +	(void)type; +	return 0; +} + +static int analyze_type_list(struct state *state, struct scope *scope, struct type *types) +{ +	foreach_type(type, types) { +		if (analyze_type(state, scope, type)) +			return -1; +	} + +	return 0; +} +  int analyze_root(struct scope *scope, struct ast *root)  {  	foreach_node(node, root) {  		assert(node->k == AST_PROC_DEF); -		scope_add_proc(scope, node); +		if (scope_add_proc(scope, node)) +			return -1; +	} + +	foreach_node(node, root) { +		struct state state = {}; +		if (analyze(&state, scope, node)) +			return -1;  	}  	return 0; @@ -133,6 +133,9 @@ struct type *tgen_type(enum type_kind kind,  void ast_append(struct ast **list, struct ast *elem)  { +	assert(list); +	assert(elem); +  	struct ast *cur = *list;  	if (!cur) {  		*list = elem; @@ -145,6 +148,23 @@ void ast_append(struct ast **list, struct ast *elem)  	cur->n = elem;  } +void type_append(struct type **list, struct type *elem) +{ +	assert(list); +	assert(elem); + +	struct type *cur = *list; +	if (!cur) { +		*list = elem; +		return; +	} + +	while (cur->n) +		cur = cur->n; + +	cur->n = elem; +} +  struct ast *ast_prepend(struct ast *list, struct ast *elem)  {  	elem->n = list; @@ -295,6 +315,41 @@ struct ast *clone_ast_list(struct ast *root)  	return new_root;  } +struct type *clone_type(struct type *type) +{ +	if (!type) +		return NULL; + +	struct type *new = create_empty_type(); +	if (!new) +		return NULL; + +	new->k = type->k; +	if (type->id && !(new->id = strdup(type->id))) +		return NULL; + +	if (type->t0 && !(new->t0 = clone_type_list(type->t0))) +		return NULL; + +	return new; +} + +struct type *clone_type_list(struct type *root) +{ +	struct type *n = root, *new_root = NULL, *prev = NULL; +	while (n) { +		struct type *new = clone_type(n); + +		if (prev) prev->n = new; +		else new_root = new; + +		prev = new; +		n = n->n; +	} + +	return new_root; +} +  int ast_visit(ast_callback_t before, ast_callback_t after, struct ast *n,                void *d)  { @@ -346,6 +401,17 @@ size_t ast_list_len(struct ast *node)  	return count;  } +size_t type_list_len(struct type *node) +{ +	size_t count = 0; +	while (node) { +		count++; +		node = node->n; +	} + +	return count; +} +  struct ast *ast_last(struct ast *list)  {  	if (!list) @@ -482,3 +548,43 @@ void fix_closures(struct ast *root)  		root = next;  	}  } + +bool types_match(struct type *a, struct type *b) +{ +	if (!a && !b) +		return true; + +	if (a && !b) +		return false; + +	if (!a && b) +		return false; + +	if (a->k != b->k) +		return false; + +	if (a->id && b->id && strcmp(a->id, b->id) != 0) +		return false; + +	if (!type_lists_match(a->t0, b->t0)) +		return false; + +	return true; +} + +bool type_lists_match(struct type *al, struct type *bl) +{ +	if (type_list_len(al) != type_list_len(bl)) +		return false; + +	struct type *a = al; +	struct type *b = bl; +	while (a && b) { +		if (!types_match(a, b)) +			return false; + +		a = a->n; +		b = b->n; +	} +	return true; +} diff --git a/src/debug.c b/src/debug.c index f47a022..f823442 100644 --- a/src/debug.c +++ b/src/debug.c @@ -15,6 +15,7 @@  #include <stdarg.h>  #include <fwd/debug.h> +#include <fwd/scope.h>  /**   * Get string representation of issue_level. @@ -120,7 +121,7 @@ void src_issue(struct src_issue issue, const char *err_msg, ...)  	va_end(args);  } -void semantic_error(struct file_ctx fctx, struct ast *node, +void semantic_error(struct scope *scope, struct ast *node,                      const char *fmt, ...)  {  	va_list args; @@ -128,12 +129,86 @@ void semantic_error(struct file_ctx fctx, struct ast *node,  	struct src_issue issue;  	issue.level = SRC_ERROR;  	issue.loc = node->loc; -	issue.fctx = fctx; +	issue.fctx = scope->fctx;  	_issue(issue, fmt, args);  	va_end(args);  } -void loc_error(struct file_ctx fctx, struct src_loc loc, +void type_mismatch(struct scope *scope, struct ast *node, +		struct type *l, struct type *r) +{ +	const char *ls = type_str(l); +	const char *rs = type_str(r); +	semantic_error(scope, node, "type mismatch: %s vs %s\n", ls, rs); +	free((void *)ls); +	free((void *)rs); +} + +static void _type_str(FILE *f, struct type *t); + +static void _type_list_str(FILE *f, struct type *types) +{ +	_type_str(f, types); + +	foreach_type(type, types->n) { +		fprintf(f, ", "); +		_type_str(f, type); +	} +} + +static void _type_str(FILE *f, struct type *type) +{ +	if (!type) +		return; + +	switch (type->k) { +	case TYPE_VOID: +		fprintf(f, "void"); +		break; + +	case TYPE_PTR: +		fprintf(f, "*"); +		break; + +	case TYPE_REF: +		fprintf(f, "&"); +		break; + +	case TYPE_ID: +		fprintf(f, "%s", type->id); +		break; + +	case TYPE_CALLABLE: +		fprintf(f, "("); +		_type_list_str(f, type->t0); +		fprintf(f, ")"); +		break; + +	case TYPE_CONSTRUCT: +		fprintf(f, "%s![", type->id); +		_type_list_str(f, type->t0); +		fprintf(f, "]"); +		break; +	} +} + +const char *type_str(struct type *t) +{ +	if (!t) +		return strdup("NULL"); + +	char *buf = NULL; size_t size = 0; +	FILE *memstream = open_memstream(&buf, &size); +	if (!memstream) +		return NULL; + +	_type_str(memstream, t); + +	fclose(memstream); +	return buf; +} + +void loc_error(struct scope *scope, struct src_loc loc,                 const char *fmt, ...)  {  	va_list args; @@ -141,12 +216,12 @@ void loc_error(struct file_ctx fctx, struct src_loc loc,  	struct src_issue issue;  	issue.level = SRC_ERROR;  	issue.loc = loc; -	issue.fctx = fctx; +	issue.fctx = scope->fctx;  	_issue(issue, fmt, args);  	va_end(args);  } -void semantic_warn(struct file_ctx fctx, struct ast *node, const char *fmt, +void semantic_warn(struct scope *scope, struct ast *node, const char *fmt,                     ...)  {  	va_list args; @@ -154,12 +229,12 @@ void semantic_warn(struct file_ctx fctx, struct ast *node, const char *fmt,  	struct src_issue issue;  	issue.level = SRC_WARN;  	issue.loc = node->loc; -	issue.fctx = fctx; +	issue.fctx = scope->fctx;  	_issue(issue, fmt, args);  	va_end(args);  } -void semantic_info(struct file_ctx fctx, struct ast *node, const char *fmt, +void semantic_info(struct scope *scope, struct ast *node, const char *fmt,                     ...)  {  	va_list args; @@ -167,7 +242,7 @@ void semantic_info(struct file_ctx fctx, struct ast *node, const char *fmt,  	struct src_issue issue;  	issue.level = SRC_INFO;  	issue.loc = node->loc; -	issue.fctx = fctx; +	issue.fctx = scope->fctx;  	_issue(issue, fmt, args);  	va_end(args);  } diff --git a/src/scope.c b/src/scope.c index 869e0d6..a7dfa69 100644 --- a/src/scope.c +++ b/src/scope.c @@ -102,8 +102,8 @@ int scope_add_var(struct scope *scope, struct ast *var)  {  	struct ast *exists = scope_find_symbol(scope, var_id(var));  	if (exists) { -		semantic_error(scope->fctx, var, "var redefined"); -		semantic_info(scope->fctx, exists, "previously here"); +		semantic_error(scope, var, "var redefined"); +		semantic_info(scope, exists, "previously here");  		return -1;  	} @@ -116,8 +116,8 @@ int scope_add_proc(struct scope *scope, struct ast *proc)  	assert(proc->k == AST_PROC_DEF);  	struct ast *exists = file_scope_find_symbol(scope, proc_id(proc));  	if (exists) { -		semantic_error(scope->fctx, proc, "proc redefined"); -		semantic_info(scope->fctx, exists, "previously here"); +		semantic_error(scope, proc, "proc redefined"); +		semantic_info(scope, exists, "previously here");  		return -1;  	} | 
