diff options
Diffstat (limited to 'src/check.c')
| -rw-r--r-- | src/check.c | 767 | 
1 files changed, 767 insertions, 0 deletions
| diff --git a/src/check.c b/src/check.c new file mode 100644 index 0000000..54929a1 --- /dev/null +++ b/src/check.c @@ -0,0 +1,767 @@ +#include <posthaste/check.h> + +#define UNUSED(x) (void)x + +struct state { +	struct ast *parent; +}; + +int analyze_visibility(struct scope *scope, struct ast *n) +{ +	switch (n->k) { +	case AST_FUNC_DEF: +		if (scope_add_func(scope, n)) +			return -1; + +		break; + +	case AST_PROC_DEF: +		if (scope_add_proc(scope, n)) +			return -1; + +		break; +	default: +	} + +	return 0; +} + +static bool has_type(struct ast *n) +{ +	return n->t != 0; +} + +static int analyze(struct state *state, struct scope *scope, struct ast *n); +static int analyze_list(struct state *state, struct scope *scope, struct ast *l) +{ +	foreach_node(n, l) { +		if (analyze(state, scope, n)) +			return -1; +	} + +	return 0; +} + +static struct ast *file_scope_find_analyzed(struct state *state, struct scope *scope, char *id) +{ +	struct ast *exists = file_scope_find(scope, id); +	if (!exists) +		return NULL; + +	if (analyze(state, scope, exists)) +		return NULL; + +	return exists; +} + +static int analyze_type(struct scope *scope, struct ast *n) +{ +	if (!n->type) { +		n->t = TYPE_VOID; +		return 0; +	} + +	if (same_id(n->type, "int")) { +		n->t = TYPE_INT; +		return 0; +	} +	else if (same_id(n->type, "date")) { +		n->t = TYPE_DATE; +		return 0; +	} + +	semantic_error(scope, n, "no such type: %s", n->type); +	return -1; +} + +static int analyze_func_def(struct state *state, struct scope *scope, struct ast *f) +{ +	UNUSED(state); +	if (analyze_type(scope, f)) +		return -1; + +	/* slightly hacky, but allows recursion */ +	ast_set_flags(f, AST_FLAG_CHECKED); + +	struct scope *func_scope = create_scope(); +	scope_add_scope(scope, func_scope); +	f->scope = func_scope; + +	struct state func_state = {.parent = f}; +	if (analyze_list(&func_state, func_scope, func_formals(f))) +		return -1; + +	if (analyze_list(&func_state, func_scope, func_vars(f))) +		return -1; + +	return analyze(&func_state, func_scope, func_body(f)); +} + +static int analyze_proc_def(struct state *state, struct scope *scope, struct ast *p) +{ +	UNUSED(state); +	if (analyze_type(scope, p)) +		return -1; + +	/* slightly hacky, but allows recursion */ +	ast_set_flags(p, AST_FLAG_CHECKED); + +	struct scope *proc_scope = create_scope(); +	scope_add_scope(scope, proc_scope); +	p->scope = proc_scope; + +	struct state proc_state = {.parent = p}; +	if (analyze_list(&proc_state, proc_scope, proc_formals(p))) +		return -1; + +	if (analyze_list(&proc_state, proc_scope, proc_vars(p))) +		return -1; + +	if (analyze_list(&proc_state, proc_scope, proc_body(p))) +		return -1; + +	struct ast *last = ast_last(proc_body(p)); +	if (p->t != TYPE_VOID && last->k != AST_RETURN) { +		semantic_error(scope, p, "can't prove that proc returns"); +		return -1; +	} + +	return 0; +} + +static int analyze_var_def(struct state *state, struct scope *scope, struct ast *n) +{ +	if (scope_add_var(scope, n)) +		return -1; + +	struct ast *expr = var_expr(n); +	if (analyze(state, scope, expr)) +		return -1; + +	n->t = expr->t; +	return 0; +} + +static int analyze_formal_def(struct state *state, struct scope *scope, struct ast *n) +{ +	UNUSED(state); +	if (analyze_type(scope, n)) +		return -1; + +	if (scope_add_formal(scope, n)) +		return -1; + +	return 0; +} + +static int analyze_neg(struct state *state, struct scope *scope, struct ast *n) +{ +	struct ast *expr = neg_expr(n); +	if (analyze(state, scope, expr)) +		return -1; + +	if (expr->t != TYPE_INT) { +		semantic_error(scope, n, +				"negation requires %s, got %s\n", +				type_str(TYPE_INT), +				type_str(expr->t)); +		return -1; +	} + +	n->t = expr->t; +	return 0; +} + +static int analyze_pos(struct state *state, struct scope *scope, struct ast *p) +{ +	struct ast *expr = pos_expr(p); +	if (analyze(state, scope, expr)) +		return -1; + +	if (expr->t != TYPE_INT) { +		semantic_error(scope, p, +				"unary pos requires %s, got %s\n", +				type_str(TYPE_INT), +				type_str(expr->t)); +		return -1; +	} + +	p->t = expr->t; +	return 0; +} + +static int analyze_const_date(struct state *state, struct scope *scope, struct ast *d) +{ +	UNUSED(state); +	UNUSED(scope); +	/* should be in correct format as the lexer took care of it */ +	d->t = TYPE_DATE; +	return 0; +} + +static int analyze_const_int(struct state *state, struct scope *scope, struct ast *i) +{ +	UNUSED(state); +	UNUSED(scope); +	i->t = TYPE_INT; +	return 0; +} + +static int analyze_const_string(struct state *state, struct scope *scope, struct ast *s) +{ +	UNUSED(state); +	UNUSED(scope); +	s->t = TYPE_STRING; +	return 0; +} + +static int analyze_eq(struct state *state, struct scope *scope, struct ast *e) +{ +	struct ast *l = eq_l(e); +	if (analyze(state, scope, l)) +		return -1; + +	struct ast *r = eq_r(e); +	if (analyze(state, scope, r)) +		return -1; + +	if (r->t != l->t) { +		semantic_error(scope, e, "type mismatch: %s vs %s\n", +				type_str(l->t), +				type_str(r->t)); +		return -1; +	} + +	e->t = TYPE_BOOL; +	return 0; +} + +static int analyze_lt(struct state *state, struct scope *scope, struct ast *e) +{ +	struct ast *l = lt_l(e); +	if (analyze(state, scope, l)) +		return -1; + +	struct ast *r = lt_r(e); +	if (analyze(state, scope, r)) +		return -1; + +	if (r->t != l->t) { +		semantic_error(scope, e, "type mismatch: %s vs %s\n", +				type_str(l->t), +				type_str(r->t)); +		return -1; +	} + +	e->t = TYPE_BOOL; +	return 0; +} + +static int analyze_add(struct state *state, struct scope *scope, struct ast *a) +{ +	struct ast *l = add_l(a); +	if (analyze(state, scope, l)) +		return -1; + +	struct ast *r = add_r(a); +	if (analyze(state, scope, r)) +		return -1; + +	if (r->t == TYPE_DATE) { +		semantic_error(scope, r, "date not allowed as rhs to addition"); +		return -1; +	} + +	if (l->t == TYPE_DATE) { +		a->k = AST_DATE_ADD; +	} + +	a->t = l->t; +	return 0; +} + +static int analyze_sub(struct state *state, struct scope *scope, struct ast *s) +{ +	struct ast *l = sub_l(s); +	if (analyze(state, scope, l)) +		return -1; + +	struct ast *r = sub_r(s); +	if (analyze(state, scope, r)) +		return -1; + +	if (l->t == TYPE_DATE && r->t == TYPE_DATE) { +		s->k = AST_DATE_DIFF; +		s->t = TYPE_INT; +		return 0; +	} + +	if (l->t == TYPE_DATE && r->t == TYPE_INT) { +		s->k = AST_DATE_SUB; +		s->t = TYPE_DATE; +		return 0; +	} + +	if (l->t == TYPE_INT && r->t == TYPE_INT) { +		s->t = TYPE_INT; +		return 0; +	} + +	semantic_error(scope, s, "illegal subtraction types: %s, %s", +			type_str(l->t), type_str(r->t)); +	return -1; +} + +static int analyze_mul(struct state *state, struct scope *scope, struct ast *m) +{ +	struct ast *l = mul_l(m); +	if (analyze(state, scope, l)) +		return -1; + +	if (l->t != TYPE_INT) { +		semantic_error(scope, l, "expected %s, got %s", +				type_str(TYPE_INT), type_str(l->t)); +		return -1; +	} + +	struct ast *r = mul_r(m); +	if (analyze(state, scope, r)) +		return -1; + +	if (r->t != TYPE_INT) { +		semantic_error(scope, r, "expected %s, got %s", +				type_str(TYPE_INT), type_str(r->t)); +		return -1; +	} + +	m->t = TYPE_INT; +	return 0; +} + +static int analyze_div(struct state *state, struct scope *scope, struct ast *d) +{ +	struct ast *l = div_l(d); +	if (analyze(state, scope, l)) +		return -1; + +	if (l->t != TYPE_INT) { +		semantic_error(scope, l, "expected %s, got %s", +				type_str(TYPE_INT), type_str(l->t)); +		return -1; +	} + +	struct ast *r = div_r(d); +	if (analyze(state, scope, r)) +		return -1; + +	if (r->t != TYPE_INT) { +		semantic_error(scope, r, "expected %s, got %s", +				type_str(TYPE_INT), type_str(r->t)); +		return -1; +	} + +	d->t = TYPE_INT; +	return 0; +} + +static int analyze_print(struct state *state, struct scope *scope, struct ast *p) +{ +	struct ast *items = print_items(p); +	struct ast *fixups = NULL, *last_fixup = NULL; + +	foreach_node(n, items) { +		if (analyze(state, scope, n)) +			return -1; + +		if (n->t == TYPE_VOID) { +			semantic_error(scope, n, "void not printable"); +			return -1; +		} + +		/* build new list of nodes, helps a bit when lowering */ +		struct ast *fixup = fixup_print_item(n); +		if (!fixups) +			fixups = fixup; + +		if (last_fixup) { +			last_fixup->n = fixup; +		} + +		last_fixup = fixup; +	} + +	print_items(p) = fixups; +	p->t = TYPE_VOID; +	return 0; +} + +static int analyze_return(struct state *state, struct scope *scope, struct ast *r) +{ +	struct ast *parent = state->parent; +	if (!parent) { +		semantic_error(scope, r, "stray return"); +		return -1; +	} + + +	if (!has_type(parent)) { +		semantic_error(scope, r, "stray return in proc without return type"); +		return -1; +	} + +	struct ast *expr = return_expr(r); +	if (analyze(state, scope, expr)) +		return -1; + +	if (expr->t != parent->t) { +		semantic_error(scope, r, "return type mismatch: %s vs %s", +				type_str(parent->t), type_str(expr->t)); +		return -1; +	} + +	r->t = expr->t; +	return 0; +} + +static int analyze_id(struct state *state, struct scope *scope, struct ast *i) +{ +	UNUSED(state); +	struct ast *exists = file_scope_find_analyzed(state, scope, i->id); +	if (!exists) { +		semantic_error(scope, i, "no such symbol"); +		return -1; +	} + +	if (exists->k != AST_FORMAL_DEF && exists->k != AST_VAR_DEF) { +		semantic_error(scope, i, "no such variable"); +		return -1; +	} + +	i->t = exists->t; +	return 0; +} + +static int analyze_dot(struct state *state, struct scope *scope, struct ast *d) +{ +	struct ast *base = dot_base(d); +	if (analyze(state, scope, base)) +		return -1; + +	if (base->t != TYPE_DATE) { +		semantic_error(scope, d, "expected %d, got %d", +				type_str(TYPE_DATE), type_str(base->t)); +		return -1; +	} + +	d->t = TYPE_INT; +	if (same_id(d->id, "day")) +		return 0; + +	if (same_id(d->id, "month")) +		return 0; + +	if (same_id(d->id, "year")) +		return 0; + +	semantic_error(scope, d, "illegal write attribute %s", d->id); +	return -1; +} + +static int analyze_attr(struct state *state, struct scope *scope, struct ast *d) +{ +	struct ast *base = attr_base(d); +	if (analyze(state, scope, base)) +		return -1; + +	if (base->t != TYPE_DATE) { +		semantic_error(scope, d, "expected %d, got %d", +				type_str(TYPE_DATE), type_str(base->t)); +		return -1; +	} + +	d->t = TYPE_INT; +	if (same_id(d->id, "day")) +		return 0; + +	if (same_id(d->id, "month")) +		return 0; + +	if (same_id(d->id, "year")) +		return 0; + +	if (same_id(d->id, "weekday")) { +		d->t = TYPE_STRING; +		return 0; +	} + +	if (same_id(d->id, "weeknum")) +		return 0; + +	semantic_error(scope, d, "illegal write attribute %s", d->id); +	return -1; +} + +static int analyze_assign(struct state *state, struct scope *scope, struct ast *n) +{ +	struct ast *l = assign_l(n); +	if (analyze(state, scope, l)) +		return -1; + +	struct ast *r = assign_r(n); +	if (analyze(state, scope, r)) +		return -1; + +	if (l->t != r->t) { +		semantic_error(scope, n, "type mismatch: %s vs %s", +				type_str(l->t), type_str(r->t)); +		return -1; +	} + +	n->t = l->t; +	return 0; +} + +static int analyze_today(struct state *state, struct scope *scope, struct ast *c) +{ +	UNUSED(state); +	if (ast_list_len(func_call_args(c)) != 0) { +		semantic_error(scope, c, "expected 0 arguments, got %zd", +				ast_list_len(func_call_args(c))); +		return -1; +	} + +	c->k = AST_BUILTIN_CALL; +	c->t = TYPE_DATE; +	return 0; +} + +static int analyze_proc_call(struct state *state, struct scope *scope, struct ast *c) +{ +	struct ast *exists = file_scope_find_analyzed(state, scope, proc_call_id(c)); +	if (!exists || exists->k != AST_PROC_DEF) { +		semantic_error(scope, c, "no such proc"); +		return -1; +	} + +	struct ast *args = proc_call_args(c); +	if (analyze_list(state, scope, args)) +		return -1; + +	struct ast *formal = proc_formals(exists); +	if (ast_list_len(formal) != ast_list_len(args)) { +		semantic_error(scope, c, "expected %s args, got %s", +				ast_list_len(formal), ast_list_len(args)); +		return -1; +	} + +	foreach_node(a, args) { +		if (a->t != formal->t) { +			semantic_error(scope, c, "expected %s, got %s", +					type_str(formal->t), type_str(a->t)); +		} + +		formal = formal->n; +	} + +	c->t = exists->t; +	return 0; +} + +static int analyze_func_call(struct state *state, struct scope *scope, struct ast *c) +{ +	/* handle special Today() built-in */ +	if (same_id(func_call_id(c), "Today")) +		return analyze_today(state, scope, c); + +	struct ast *exists = file_scope_find_analyzed(state, scope, func_call_id(c)); +	if (!exists || exists->k != AST_FUNC_DEF) { +		semantic_error(scope, c, "no such func"); +		return -1; +	} + +	struct ast *args = func_call_args(c); +	if (analyze_list(state, scope, args)) +		return -1; + +	struct ast *formal = func_formals(exists); +	if (ast_list_len(formal) != ast_list_len(args)) { +		semantic_error(scope, c, "expected %s args, got %s", +				ast_list_len(formal), ast_list_len(args)); +		return -1; +	} + +	foreach_node(a, args) { +		if (a->t != formal->t) { +			semantic_error(scope, c, "expected %s, got %s", +					type_str(formal->t), type_str(a->t)); +		} + +		formal = formal->n; +	} + +	c->t = exists->t; +	return 0; +} + +static int analyze_until(struct state *state, struct scope *scope, struct ast *n) +{ +	struct scope *until_scope = create_scope(); +	scope_add_scope(scope, until_scope); + +	struct ast *body = until_body(n); +	if (analyze_list(state, until_scope, body)) +		return -1; + +	struct ast *cond = until_cond(n); +	if (analyze(state, until_scope, cond)) +		return -1; + +	if (cond->t != TYPE_BOOL) { +		semantic_error(scope, cond, "expected %s, got %s", +				type_str(TYPE_BOOL), type_str(cond->t)); +		return -1; +	} + +	n->t = TYPE_VOID; +	return 0; +} + +static int analyze_unless(struct state *state, struct scope *scope, struct ast *n) +{ +	struct scope *unless_scope = create_scope(); +	struct scope *otherwise_scope = create_scope(); +	scope_add_scope(scope, unless_scope); +	scope_add_scope(scope, otherwise_scope); + +	struct ast *body = unless_body(n); +	if (analyze_list(state, unless_scope, body)) +		return -1; + +	struct ast *cond = unless_cond(n); +	if (analyze(state, scope, cond)) +		return -1; + +	if (cond->t != TYPE_BOOL) { +		semantic_error(scope, cond, "expected %s, got %s", +				type_str(TYPE_BOOL), type_str(cond->t)); +		return -1; +	} + +	struct ast *otherwise = unless_otherwise(n); +	if (otherwise && analyze_list(state, otherwise_scope, otherwise)) +		return -1; + +	n->t = TYPE_VOID; +	return 0; +} + +static int analyze_unless_expr(struct state *state, struct scope *scope, struct ast *n) +{ +	struct ast *body = unless_expr_body(n); +	if (analyze(state, scope, body)) +		return -1; + +	struct ast *cond = unless_expr_cond(n); +	if (analyze(state, scope, cond)) +		return -1; + +	if (cond->t != TYPE_BOOL) { +		semantic_error(scope, cond, "expected %s, got %s", +				type_str(TYPE_BOOL), type_str(cond->t)); +		return -1; +	} + +	struct ast *otherwise = unless_expr_otherwise(n); +	if (analyze(state, scope, otherwise)) +		return -1; + +	if (body->t != otherwise->t) { +		semantic_error(scope, n, "type mismatch: %s vs %s", +				type_str(body->t), type_str(otherwise->t)); +		return -1; +	} + +	n->t = body->t; +	return 0; +} + +static int analyze(struct state *state, struct scope *scope, struct ast *n) +{ +	int ret = 0; + +	if (ast_flags(n, AST_FLAG_CHECKED)) { +		assert(has_type(n)); +		return 0; +	} + +	if (ast_flags(n, AST_FLAG_INIT)) { +		semantic_error(scope, n, "semantic loop detected"); +		return -1; +	} + +	ast_set_flags(n, AST_FLAG_INIT); + +	/* generally what we want to happen, special cases can handle themselves */ +	if (!n->scope) +		n->scope = scope; + +	switch (n->k) { +	case AST_FUNC_DEF: ret = analyze_func_def(state, scope, n); break; +	case AST_PROC_DEF: ret = analyze_proc_def(state, scope, n); break; +	case AST_VAR_DEF: ret = analyze_var_def(state, scope, n); break; +	case AST_FORMAL_DEF: ret = analyze_formal_def(state, scope, n); break; +	case AST_NEG: ret = analyze_neg(state, scope, n); break; +	case AST_POS: ret = analyze_pos(state, scope, n); break; +	case AST_CONST_DATE: ret = analyze_const_date(state, scope, n); break; +	case AST_CONST_INT: ret = analyze_const_int(state, scope, n); break; +	case AST_CONST_STRING: ret = analyze_const_string(state, scope, n); break; +	case AST_EQ: ret = analyze_eq(state, scope, n); break; +	case AST_LT: ret = analyze_lt(state, scope, n); break; +	case AST_ADD: ret = analyze_add(state, scope, n); break; +	case AST_SUB: ret = analyze_sub(state, scope, n); break; +	case AST_MUL: ret = analyze_mul(state, scope, n); break; +	case AST_DIV: ret = analyze_div(state, scope, n); break; +	case AST_PRINT: ret = analyze_print(state, scope, n); break; +	case AST_RETURN: ret = analyze_return(state, scope, n); break; +	case AST_ID: ret = analyze_id(state, scope, n); break; +	case AST_DOT: ret = analyze_dot(state, scope, n); break; +	case AST_ATTR: ret = analyze_attr(state, scope, n); break; +	case AST_ASSIGN: ret = analyze_assign(state, scope, n); break; +	case AST_PROC_CALL: ret = analyze_proc_call(state, scope, n); break; +	case AST_FUNC_CALL: ret = analyze_func_call(state, scope, n); break; +	case AST_UNTIL: ret = analyze_until(state, scope, n); break; +	case AST_UNLESS: ret = analyze_unless(state, scope, n); break; +	case AST_UNLESS_EXPR: ret = analyze_unless_expr(state, scope, n); break; +	default: break; +	} + +	if (ret == 0) { +		assert(has_type(n)); +	} + +	ast_set_flags(n, AST_FLAG_CHECKED); +	return ret; +} + +int check(struct scope *scope, struct ast *root) +{ +	foreach_node(n, root) { +		if (analyze_visibility(scope, n)) +			return -1; +	} + +	foreach_node(n, root) { +		struct state state = {0}; +		if (analyze(&state, scope, n)) +			return -1; +	} + +#ifdef DEBUG +	foreach_node(n, root) { +		printf("// after checking:\n"); +		ast_dump(0, n); +	} +#endif + +	return 0; +} | 
