aboutsummaryrefslogtreecommitdiff
path: root/src/check.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/check.c')
-rw-r--r--src/check.c767
1 files changed, 767 insertions, 0 deletions
diff --git a/src/check.c b/src/check.c
new file mode 100644
index 0000000..54929a1
--- /dev/null
+++ b/src/check.c
@@ -0,0 +1,767 @@
+#include <posthaste/check.h>
+
+/*
+ * Mark x as intentionally unused. The argument and the whole expansion are
+ * parenthesized so the cast applies to the complete expression passed in.
+ */
+#define UNUSED(x) ((void)(x))
+
+/* Context handed down through the analysis functions. */
+struct state {
+    /*
+     * Func/proc definition whose body is being analyzed; zero-initialized
+     * (NULL) when analysis starts from check().
+     */
+    struct ast *parent;
+};
+
+/*
+ * Register a func or proc definition in scope, ahead of the full analysis
+ * pass that check() runs afterwards. Nodes of any other kind are ignored.
+ *
+ * Returns 0 on success, -1 if adding the definition to the scope failed.
+ */
+int analyze_visibility(struct scope *scope, struct ast *n)
+{
+    switch (n->k) {
+    case AST_FUNC_DEF:
+        if (scope_add_func(scope, n))
+            return -1;
+
+        break;
+
+    case AST_PROC_DEF:
+        if (scope_add_proc(scope, n))
+            return -1;
+
+        break;
+
+    default:
+        /* a label must be followed by a statement before C23 */
+        break;
+    }
+
+    return 0;
+}
+
+/* Report whether node n has been given a type, i.e. its t field is nonzero. */
+static bool has_type(struct ast *n)
+{
+    if (n->t == 0)
+        return false;
+
+    return true;
+}
+
+static int analyze(struct state *state, struct scope *scope, struct ast *n);
+
+/*
+ * Analyze every node that foreach_node() yields from list l.
+ *
+ * Stops at the first node whose analysis fails and returns -1; returns 0
+ * when all nodes were analyzed successfully.
+ */
+static int analyze_list(struct state *state, struct scope *scope, struct ast *l)
+{
+    foreach_node(node, l) {
+        if (analyze(state, scope, node) != 0)
+            return -1;
+    }
+
+    return 0;
+}
+
+/*
+ * Look up id with file_scope_find() and make sure the node found has been
+ * analyzed before handing it back.
+ *
+ * Returns the node, or NULL if the lookup failed or analyzing it failed.
+ */
+static struct ast *file_scope_find_analyzed(struct state *state, struct scope *scope, char *id)
+{
+    struct ast *found = file_scope_find(scope, id);
+
+    if (found && analyze(state, scope, found) == 0)
+        return found;
+
+    return NULL;
+}
+
+/*
+ * Translate the type name stored in n->type into n->t.
+ *
+ * A missing type name maps to TYPE_VOID, "int" to TYPE_INT and "date" to
+ * TYPE_DATE. Any other name is reported as an error.
+ *
+ * Returns 0 on success, -1 for an unknown type name.
+ */
+static int analyze_type(struct scope *scope, struct ast *n)
+{
+    if (!n->type) {
+        n->t = TYPE_VOID;
+    } else if (same_id(n->type, "int")) {
+        n->t = TYPE_INT;
+    } else if (same_id(n->type, "date")) {
+        n->t = TYPE_DATE;
+    } else {
+        semantic_error(scope, n, "no such type: %s", n->type);
+        return -1;
+    }
+
+    return 0;
+}
+
+/*
+ * Analyze a func definition: resolve its return type, give it a fresh child
+ * scope, then analyze formals, local variables and finally the body, all with
+ * f recorded as the parent definition.
+ *
+ * Returns 0 on success, -1 as soon as any step fails (the body's own result
+ * is passed through unchanged).
+ */
+static int analyze_func_def(struct state *state, struct scope *scope, struct ast *f)
+{
+    UNUSED(state);
+
+    if (analyze_type(scope, f) != 0)
+        return -1;
+
+    /*
+     * Flag f as checked before descending, so analyze() returns early when
+     * the body refers back to f (recursion).
+     */
+    ast_set_flags(f, AST_FLAG_CHECKED);
+
+    struct scope *inner = create_scope();
+    scope_add_scope(scope, inner);
+    f->scope = inner;
+
+    struct state inner_state = {.parent = f};
+    if (analyze_list(&inner_state, inner, func_formals(f)) != 0
+        || analyze_list(&inner_state, inner, func_vars(f)) != 0)
+        return -1;
+
+    return analyze(&inner_state, inner, func_body(f));
+}
+
+/*
+ * Analyze a proc definition: resolve its return type, give it a fresh child
+ * scope, then analyze formals, local variables and the body statements with
+ * p recorded as the parent definition.
+ *
+ * A proc with a non-void return type must end its body with a return
+ * statement; anything else (including a body with no last statement) is
+ * rejected.
+ *
+ * Returns 0 on success, -1 on any error.
+ */
+static int analyze_proc_def(struct state *state, struct scope *scope, struct ast *p)
+{
+    UNUSED(state);
+    if (analyze_type(scope, p))
+        return -1;
+
+    /* slightly hacky, but allows recursion */
+    ast_set_flags(p, AST_FLAG_CHECKED);
+
+    struct scope *proc_scope = create_scope();
+    scope_add_scope(scope, proc_scope);
+    p->scope = proc_scope;
+
+    struct state proc_state = {.parent = p};
+    if (analyze_list(&proc_state, proc_scope, proc_formals(p)))
+        return -1;
+
+    if (analyze_list(&proc_state, proc_scope, proc_vars(p)))
+        return -1;
+
+    if (analyze_list(&proc_state, proc_scope, proc_body(p)))
+        return -1;
+
+    /*
+     * NOTE(review): ast_last() presumably yields NULL for an empty body;
+     * guard against that instead of dereferencing it blindly.
+     */
+    struct ast *last = ast_last(proc_body(p));
+    if (p->t != TYPE_VOID && (!last || last->k != AST_RETURN)) {
+        semantic_error(scope, p, "can't prove that proc returns");
+        return -1;
+    }
+
+    return 0;
+}
+
+/*
+ * Analyze a variable definition: add the variable to scope, analyze its
+ * initializer expression and give the variable the initializer's type.
+ *
+ * Returns 0 on success, -1 if the variable could not be added or the
+ * initializer failed analysis.
+ */
+static int analyze_var_def(struct state *state, struct scope *scope, struct ast *n)
+{
+    if (scope_add_var(scope, n) != 0)
+        return -1;
+
+    struct ast *init = var_expr(n);
+    if (analyze(state, scope, init) != 0)
+        return -1;
+
+    n->t = init->t;
+    return 0;
+}
+
+/*
+ * Analyze a formal parameter: resolve its declared type first, then add it
+ * to scope.
+ *
+ * Returns 0 on success, -1 if either step fails.
+ */
+static int analyze_formal_def(struct state *state, struct scope *scope, struct ast *n)
+{
+    UNUSED(state);
+
+    if (analyze_type(scope, n) || scope_add_formal(scope, n))
+        return -1;
+
+    return 0;
+}
+
+/*
+ * Analyze unary negation: the operand must be TYPE_INT, and the negation
+ * takes the operand's type.
+ *
+ * Returns 0 on success, -1 if the operand fails analysis or is not an int.
+ */
+static int analyze_neg(struct state *state, struct scope *scope, struct ast *n)
+{
+    struct ast *operand = neg_expr(n);
+    if (analyze(state, scope, operand) != 0)
+        return -1;
+
+    if (operand->t == TYPE_INT) {
+        n->t = operand->t;
+        return 0;
+    }
+
+    semantic_error(scope, n,
+                   "negation requires %s, got %s\n",
+                   type_str(TYPE_INT),
+                   type_str(operand->t));
+    return -1;
+}
+
+/*
+ * Analyze unary plus: the operand must be TYPE_INT, and the expression takes
+ * the operand's type.
+ *
+ * Returns 0 on success, -1 if the operand fails analysis or is not an int.
+ */
+static int analyze_pos(struct state *state, struct scope *scope, struct ast *p)
+{
+    struct ast *operand = pos_expr(p);
+    if (analyze(state, scope, operand) != 0)
+        return -1;
+
+    if (operand->t == TYPE_INT) {
+        p->t = operand->t;
+        return 0;
+    }
+
+    semantic_error(scope, p,
+                   "unary pos requires %s, got %s\n",
+                   type_str(TYPE_INT),
+                   type_str(operand->t));
+    return -1;
+}
+
+/*
+ * Type a date literal as TYPE_DATE. No format validation happens here; the
+ * lexer is expected to have taken care of it already. Always returns 0.
+ */
+static int analyze_const_date(struct state *state, struct scope *scope, struct ast *d)
+{
+    UNUSED(scope);
+    UNUSED(state);
+
+    d->t = TYPE_DATE;
+    return 0;
+}
+
+/* Type an integer literal as TYPE_INT. Always returns 0. */
+static int analyze_const_int(struct state *state, struct scope *scope, struct ast *i)
+{
+    UNUSED(scope);
+    UNUSED(state);
+
+    i->t = TYPE_INT;
+    return 0;
+}
+
+/* Type a string literal as TYPE_STRING. Always returns 0. */
+static int analyze_const_string(struct state *state, struct scope *scope, struct ast *s)
+{
+    UNUSED(scope);
+    UNUSED(state);
+
+    s->t = TYPE_STRING;
+    return 0;
+}
+
+/*
+ * Analyze an equality comparison: both operands must have the same type, and
+ * the comparison itself is TYPE_BOOL.
+ *
+ * Returns 0 on success, -1 if an operand fails analysis or the types differ.
+ */
+static int analyze_eq(struct state *state, struct scope *scope, struct ast *e)
+{
+    struct ast *lhs = eq_l(e);
+    if (analyze(state, scope, lhs) != 0)
+        return -1;
+
+    struct ast *rhs = eq_r(e);
+    if (analyze(state, scope, rhs) != 0)
+        return -1;
+
+    if (lhs->t == rhs->t) {
+        e->t = TYPE_BOOL;
+        return 0;
+    }
+
+    semantic_error(scope, e, "type mismatch: %s vs %s\n",
+                   type_str(lhs->t),
+                   type_str(rhs->t));
+    return -1;
+}
+
+/*
+ * Analyze a less-than comparison: both operands must have the same type, and
+ * the comparison itself is TYPE_BOOL.
+ *
+ * Returns 0 on success, -1 if an operand fails analysis or the types differ.
+ */
+static int analyze_lt(struct state *state, struct scope *scope, struct ast *e)
+{
+    struct ast *lhs = lt_l(e);
+    if (analyze(state, scope, lhs) != 0)
+        return -1;
+
+    struct ast *rhs = lt_r(e);
+    if (analyze(state, scope, rhs) != 0)
+        return -1;
+
+    if (lhs->t == rhs->t) {
+        e->t = TYPE_BOOL;
+        return 0;
+    }
+
+    semantic_error(scope, e, "type mismatch: %s vs %s\n",
+                   type_str(lhs->t),
+                   type_str(rhs->t));
+    return -1;
+}
+
+/*
+ * Analyze an addition. Accepted operand combinations:
+ *
+ *   int  + int  -> int
+ *   date + int  -> date (node kind is rewritten to AST_DATE_ADD)
+ *
+ * A date on the right-hand side keeps its dedicated diagnostic. Every other
+ * combination is rejected, mirroring the strictness of analyze_sub().
+ *
+ * Returns 0 on success, -1 on any error.
+ */
+static int analyze_add(struct state *state, struct scope *scope, struct ast *a)
+{
+    struct ast *l = add_l(a);
+    if (analyze(state, scope, l))
+        return -1;
+
+    struct ast *r = add_r(a);
+    if (analyze(state, scope, r))
+        return -1;
+
+    if (r->t == TYPE_DATE) {
+        semantic_error(scope, r, "date not allowed as rhs to addition");
+        return -1;
+    }
+
+    /* the result takes the lhs type, so only int/date lhs with int rhs make sense */
+    if (r->t != TYPE_INT || (l->t != TYPE_INT && l->t != TYPE_DATE)) {
+        semantic_error(scope, a, "illegal addition types: %s, %s",
+                       type_str(l->t), type_str(r->t));
+        return -1;
+    }
+
+    if (l->t == TYPE_DATE) {
+        a->k = AST_DATE_ADD;
+    }
+
+    a->t = l->t;
+    return 0;
+}
+
+/*
+ * Analyze a subtraction. Accepted operand combinations:
+ *
+ *   date - date -> int  (node kind becomes AST_DATE_DIFF)
+ *   date - int  -> date (node kind becomes AST_DATE_SUB)
+ *   int  - int  -> int
+ *
+ * Returns 0 on success, -1 if an operand fails analysis or the combination
+ * is not one of the above.
+ */
+static int analyze_sub(struct state *state, struct scope *scope, struct ast *s)
+{
+    struct ast *lhs = sub_l(s);
+    if (analyze(state, scope, lhs) != 0)
+        return -1;
+
+    struct ast *rhs = sub_r(s);
+    if (analyze(state, scope, rhs) != 0)
+        return -1;
+
+    if (lhs->t == TYPE_DATE) {
+        if (rhs->t == TYPE_DATE) {
+            s->k = AST_DATE_DIFF;
+            s->t = TYPE_INT;
+            return 0;
+        }
+
+        if (rhs->t == TYPE_INT) {
+            s->k = AST_DATE_SUB;
+            s->t = TYPE_DATE;
+            return 0;
+        }
+    } else if (lhs->t == TYPE_INT && rhs->t == TYPE_INT) {
+        s->t = TYPE_INT;
+        return 0;
+    }
+
+    semantic_error(scope, s, "illegal subtraction types: %s, %s",
+                   type_str(lhs->t), type_str(rhs->t));
+    return -1;
+}
+
+/*
+ * Analyze a multiplication: the left operand is analyzed and type-checked
+ * first, then the right one; both must be TYPE_INT and so is the result.
+ *
+ * Returns 0 on success, -1 at the first operand that fails analysis or is
+ * not an int.
+ */
+static int analyze_mul(struct state *state, struct scope *scope, struct ast *m)
+{
+    for (int side = 0; side < 2; side++) {
+        /* side 0 is the lhs, side 1 the rhs */
+        struct ast *operand = (side == 0) ? mul_l(m) : mul_r(m);
+        if (analyze(state, scope, operand) != 0)
+            return -1;
+
+        if (operand->t != TYPE_INT) {
+            semantic_error(scope, operand, "expected %s, got %s",
+                           type_str(TYPE_INT), type_str(operand->t));
+            return -1;
+        }
+    }
+
+    m->t = TYPE_INT;
+    return 0;
+}
+
+/*
+ * Analyze a division: the left operand is analyzed and type-checked first,
+ * then the right one; both must be TYPE_INT and so is the result.
+ *
+ * Returns 0 on success, -1 at the first operand that fails analysis or is
+ * not an int.
+ */
+static int analyze_div(struct state *state, struct scope *scope, struct ast *d)
+{
+    for (int side = 0; side < 2; side++) {
+        /* side 0 is the lhs, side 1 the rhs */
+        struct ast *operand = (side == 0) ? div_l(d) : div_r(d);
+        if (analyze(state, scope, operand) != 0)
+            return -1;
+
+        if (operand->t != TYPE_INT) {
+            semantic_error(scope, operand, "expected %s, got %s",
+                           type_str(TYPE_INT), type_str(operand->t));
+            return -1;
+        }
+    }
+
+    d->t = TYPE_INT;
+    return 0;
+}
+
+/*
+ * Analyze a print statement. Each item is analyzed and must not be void;
+ * the result of fixup_print_item() for every item is chained into a new
+ * list, which then replaces the statement's item list. The statement itself
+ * is TYPE_VOID.
+ *
+ * Returns 0 on success, -1 if an item fails analysis or is void.
+ */
+static int analyze_print(struct state *state, struct scope *scope, struct ast *p)
+{
+    struct ast *original = print_items(p);
+    struct ast *head = NULL;
+    struct ast *tail = NULL;
+
+    foreach_node(item, original) {
+        if (analyze(state, scope, item) != 0)
+            return -1;
+
+        if (item->t == TYPE_VOID) {
+            semantic_error(scope, item, "void not printable");
+            return -1;
+        }
+
+        /* the rebuilt list is meant to make lowering easier */
+        struct ast *fixed = fixup_print_item(item);
+
+        /* first entry becomes the head of the new list */
+        if (head == NULL)
+            head = fixed;
+
+        /* append behind the previous entry, if there is one */
+        if (tail != NULL)
+            tail->n = fixed;
+
+        tail = fixed;
+    }
+
+    print_items(p) = head;
+    p->t = TYPE_VOID;
+    return 0;
+}
+
+/*
+ * Analyze a return statement. It must appear inside a definition
+ * (state->parent set) that has a non-void return type, and the returned
+ * expression must have exactly that type. The statement takes the
+ * expression's type.
+ *
+ * Returns 0 on success, -1 on any error.
+ */
+static int analyze_return(struct state *state, struct scope *scope, struct ast *r)
+{
+    struct ast *parent = state->parent;
+    if (!parent) {
+        semantic_error(scope, r, "stray return");
+        return -1;
+    }
+
+    /*
+     * analyze_type() gives a definition without a declared type TYPE_VOID,
+     * which still counts as "has a type", so void has to be checked
+     * explicitly for this diagnostic to be reachable.
+     */
+    if (!has_type(parent) || parent->t == TYPE_VOID) {
+        semantic_error(scope, r, "stray return in proc without return type");
+        return -1;
+    }
+
+    struct ast *expr = return_expr(r);
+    if (analyze(state, scope, expr))
+        return -1;
+
+    if (expr->t != parent->t) {
+        semantic_error(scope, r, "return type mismatch: %s vs %s",
+                       type_str(parent->t), type_str(expr->t));
+        return -1;
+    }
+
+    r->t = expr->t;
+    return 0;
+}
+
+/*
+ * Analyze an identifier used as an expression. The name must resolve (via
+ * file_scope_find_analyzed()) to a formal or variable definition, whose type
+ * the identifier then takes.
+ *
+ * Returns 0 on success, -1 if nothing is found or the symbol is not a
+ * variable.
+ */
+static int analyze_id(struct state *state, struct scope *scope, struct ast *i)
+{
+    /* state is needed for the lookup, so it is not marked UNUSED here */
+    struct ast *exists = file_scope_find_analyzed(state, scope, i->id);
+    if (!exists) {
+        semantic_error(scope, i, "no such symbol");
+        return -1;
+    }
+
+    if (exists->k != AST_FORMAL_DEF && exists->k != AST_VAR_DEF) {
+        semantic_error(scope, i, "no such variable");
+        return -1;
+    }
+
+    i->t = exists->t;
+    return 0;
+}
+
+/*
+ * Analyze a dot access. The base must be TYPE_DATE and the attribute one of
+ * "day", "month" or "year"; the access itself is TYPE_INT.
+ *
+ * Returns 0 on success, -1 if the base fails analysis, is not a date, or
+ * the attribute name is not recognized.
+ */
+static int analyze_dot(struct state *state, struct scope *scope, struct ast *d)
+{
+    struct ast *base = dot_base(d);
+    if (analyze(state, scope, base))
+        return -1;
+
+    if (base->t != TYPE_DATE) {
+        /* type_str() results are strings, so they need %s rather than %d */
+        semantic_error(scope, d, "expected %s, got %s",
+                       type_str(TYPE_DATE), type_str(base->t));
+        return -1;
+    }
+
+    d->t = TYPE_INT;
+    if (same_id(d->id, "day"))
+        return 0;
+
+    if (same_id(d->id, "month"))
+        return 0;
+
+    if (same_id(d->id, "year"))
+        return 0;
+
+    semantic_error(scope, d, "illegal write attribute %s", d->id);
+    return -1;
+}
+
+/*
+ * Analyze an attribute access. The base must be TYPE_DATE. "day", "month",
+ * "year" and "weeknum" yield TYPE_INT, "weekday" yields TYPE_STRING.
+ *
+ * Returns 0 on success, -1 if the base fails analysis, is not a date, or
+ * the attribute name is not recognized.
+ */
+static int analyze_attr(struct state *state, struct scope *scope, struct ast *d)
+{
+    struct ast *base = attr_base(d);
+    if (analyze(state, scope, base))
+        return -1;
+
+    if (base->t != TYPE_DATE) {
+        /* type_str() results are strings, so they need %s rather than %d */
+        semantic_error(scope, d, "expected %s, got %s",
+                       type_str(TYPE_DATE), type_str(base->t));
+        return -1;
+    }
+
+    d->t = TYPE_INT;
+    if (same_id(d->id, "day"))
+        return 0;
+
+    if (same_id(d->id, "month"))
+        return 0;
+
+    if (same_id(d->id, "year"))
+        return 0;
+
+    if (same_id(d->id, "weekday")) {
+        d->t = TYPE_STRING;
+        return 0;
+    }
+
+    if (same_id(d->id, "weeknum"))
+        return 0;
+
+    /*
+     * NOTE(review): same wording as analyze_dot(); "write" looks copied
+     * over -- confirm the intended message.
+     */
+    semantic_error(scope, d, "illegal write attribute %s", d->id);
+    return -1;
+}
+
+/*
+ * Analyze an assignment: target and value are analyzed in that order and
+ * must have the same type, which the assignment also takes.
+ *
+ * Returns 0 on success, -1 if either side fails analysis or the types
+ * differ.
+ */
+static int analyze_assign(struct state *state, struct scope *scope, struct ast *n)
+{
+    struct ast *target = assign_l(n);
+    if (analyze(state, scope, target) != 0)
+        return -1;
+
+    struct ast *value = assign_r(n);
+    if (analyze(state, scope, value) != 0)
+        return -1;
+
+    if (target->t == value->t) {
+        n->t = target->t;
+        return 0;
+    }
+
+    semantic_error(scope, n, "type mismatch: %s vs %s",
+                   type_str(target->t), type_str(value->t));
+    return -1;
+}
+
+/*
+ * Analyze a call to the Today() built-in: it takes no arguments, is
+ * rewritten to AST_BUILTIN_CALL and has type TYPE_DATE.
+ *
+ * Returns 0 on success, -1 if any arguments were passed.
+ */
+static int analyze_today(struct state *state, struct scope *scope, struct ast *c)
+{
+    UNUSED(state);
+
+    if (ast_list_len(func_call_args(c)) == 0) {
+        c->k = AST_BUILTIN_CALL;
+        c->t = TYPE_DATE;
+        return 0;
+    }
+
+    semantic_error(scope, c, "expected 0 arguments, got %zd",
+                   ast_list_len(func_call_args(c)));
+    return -1;
+}
+
+/*
+ * Analyze a proc call. The callee must resolve to an analyzed proc
+ * definition, the number of arguments must equal the number of formals, and
+ * each argument's type must match the corresponding formal. The call takes
+ * the proc's return type.
+ *
+ * Returns 0 on success, -1 on any error (including an argument type
+ * mismatch).
+ */
+static int analyze_proc_call(struct state *state, struct scope *scope, struct ast *c)
+{
+    struct ast *exists = file_scope_find_analyzed(state, scope, proc_call_id(c));
+    if (!exists || exists->k != AST_PROC_DEF) {
+        semantic_error(scope, c, "no such proc");
+        return -1;
+    }
+
+    struct ast *args = proc_call_args(c);
+    if (analyze_list(state, scope, args))
+        return -1;
+
+    struct ast *formal = proc_formals(exists);
+    if (ast_list_len(formal) != ast_list_len(args)) {
+        /* lengths are counts, not strings: print them as integers */
+        semantic_error(scope, c, "expected %zu args, got %zu",
+                       (size_t)ast_list_len(formal),
+                       (size_t)ast_list_len(args));
+        return -1;
+    }
+
+    /* both lists have equal length here, so formal advances in lockstep */
+    foreach_node(a, args) {
+        if (a->t != formal->t) {
+            semantic_error(scope, c, "expected %s, got %s",
+                           type_str(formal->t), type_str(a->t));
+            /* a mismatched argument is an error, not just a diagnostic */
+            return -1;
+        }
+
+        formal = formal->n;
+    }
+
+    c->t = exists->t;
+    return 0;
+}
+
+/*
+ * Analyze a func call. Calls to "Today" are handed to analyze_today().
+ * Otherwise the callee must resolve to an analyzed func definition, the
+ * number of arguments must equal the number of formals, and each argument's
+ * type must match the corresponding formal. The call takes the func's
+ * return type.
+ *
+ * Returns 0 on success, -1 on any error (including an argument type
+ * mismatch).
+ */
+static int analyze_func_call(struct state *state, struct scope *scope, struct ast *c)
+{
+    /* handle special Today() built-in */
+    if (same_id(func_call_id(c), "Today"))
+        return analyze_today(state, scope, c);
+
+    struct ast *exists = file_scope_find_analyzed(state, scope, func_call_id(c));
+    if (!exists || exists->k != AST_FUNC_DEF) {
+        semantic_error(scope, c, "no such func");
+        return -1;
+    }
+
+    struct ast *args = func_call_args(c);
+    if (analyze_list(state, scope, args))
+        return -1;
+
+    struct ast *formal = func_formals(exists);
+    if (ast_list_len(formal) != ast_list_len(args)) {
+        /* lengths are counts, not strings: print them as integers */
+        semantic_error(scope, c, "expected %zu args, got %zu",
+                       (size_t)ast_list_len(formal),
+                       (size_t)ast_list_len(args));
+        return -1;
+    }
+
+    /* both lists have equal length here, so formal advances in lockstep */
+    foreach_node(a, args) {
+        if (a->t != formal->t) {
+            semantic_error(scope, c, "expected %s, got %s",
+                           type_str(formal->t), type_str(a->t));
+            /* a mismatched argument is an error, not just a diagnostic */
+            return -1;
+        }
+
+        formal = formal->n;
+    }
+
+    c->t = exists->t;
+    return 0;
+}
+
+/*
+ * Analyze an until loop. Body and condition are analyzed, in that order, in
+ * a fresh child scope; the condition must be TYPE_BOOL. The loop itself is
+ * TYPE_VOID.
+ *
+ * Returns 0 on success, -1 if the body or condition fails analysis or the
+ * condition is not a bool.
+ */
+static int analyze_until(struct state *state, struct scope *scope, struct ast *n)
+{
+    struct scope *loop_scope = create_scope();
+    scope_add_scope(scope, loop_scope);
+
+    struct ast *stmts = until_body(n);
+    if (analyze_list(state, loop_scope, stmts) != 0)
+        return -1;
+
+    struct ast *cond = until_cond(n);
+    if (analyze(state, loop_scope, cond) != 0)
+        return -1;
+
+    if (cond->t == TYPE_BOOL) {
+        n->t = TYPE_VOID;
+        return 0;
+    }
+
+    semantic_error(scope, cond, "expected %s, got %s",
+                   type_str(TYPE_BOOL), type_str(cond->t));
+    return -1;
+}
+
+/*
+ * Analyze an unless statement. The body and the optional otherwise branch
+ * each get their own child scope; the condition is analyzed in the
+ * enclosing scope and must be TYPE_BOOL. Order of analysis: body,
+ * condition, otherwise. The statement itself is TYPE_VOID.
+ *
+ * Returns 0 on success, -1 on any error.
+ */
+static int analyze_unless(struct state *state, struct scope *scope, struct ast *n)
+{
+    struct scope *body_scope = create_scope();
+    struct scope *else_scope = create_scope();
+    scope_add_scope(scope, body_scope);
+    scope_add_scope(scope, else_scope);
+
+    struct ast *stmts = unless_body(n);
+    if (analyze_list(state, body_scope, stmts) != 0)
+        return -1;
+
+    struct ast *cond = unless_cond(n);
+    if (analyze(state, scope, cond) != 0)
+        return -1;
+
+    if (cond->t != TYPE_BOOL) {
+        semantic_error(scope, cond, "expected %s, got %s",
+                       type_str(TYPE_BOOL), type_str(cond->t));
+        return -1;
+    }
+
+    /* the otherwise branch may be absent */
+    struct ast *alt = unless_otherwise(n);
+    if (alt != NULL) {
+        if (analyze_list(state, else_scope, alt) != 0)
+            return -1;
+    }
+
+    n->t = TYPE_VOID;
+    return 0;
+}
+
+/*
+ * Analyze an unless expression. Body, condition and otherwise branch are
+ * analyzed in that order; the condition must be TYPE_BOOL and both branches
+ * must share one type, which the whole expression takes.
+ *
+ * Returns 0 on success, -1 on any error.
+ */
+static int analyze_unless_expr(struct state *state, struct scope *scope, struct ast *n)
+{
+    struct ast *value = unless_expr_body(n);
+    if (analyze(state, scope, value) != 0)
+        return -1;
+
+    struct ast *cond = unless_expr_cond(n);
+    if (analyze(state, scope, cond) != 0)
+        return -1;
+
+    if (cond->t != TYPE_BOOL) {
+        semantic_error(scope, cond, "expected %s, got %s",
+                       type_str(TYPE_BOOL), type_str(cond->t));
+        return -1;
+    }
+
+    struct ast *alt = unless_expr_otherwise(n);
+    if (analyze(state, scope, alt) != 0)
+        return -1;
+
+    if (value->t == alt->t) {
+        n->t = value->t;
+        return 0;
+    }
+
+    semantic_error(scope, n, "type mismatch: %s vs %s",
+                   type_str(value->t), type_str(alt->t));
+    return -1;
+}
+
+/*
+ * Analyze a single node by dispatching on its kind.
+ *
+ * Nodes already flagged AST_FLAG_CHECKED are accepted immediately. A node
+ * flagged AST_FLAG_INIT but not yet checked is being reached again while
+ * its own analysis is still in progress, which is reported as a semantic
+ * loop. Otherwise the node is flagged INIT, given the current scope if it
+ * has none, analyzed by its kind-specific handler and finally flagged
+ * CHECKED regardless of the outcome.
+ *
+ * Returns the handler's result: 0 on success, nonzero on error.
+ */
+static int analyze(struct state *state, struct scope *scope, struct ast *n)
+{
+    if (ast_flags(n, AST_FLAG_CHECKED)) {
+        assert(has_type(n));
+        return 0;
+    }
+
+    if (ast_flags(n, AST_FLAG_INIT)) {
+        semantic_error(scope, n, "semantic loop detected");
+        return -1;
+    }
+
+    ast_set_flags(n, AST_FLAG_INIT);
+
+    /* default scope; handlers that need something else set their own */
+    if (n->scope == NULL)
+        n->scope = scope;
+
+    int rc = 0;
+    switch (n->k) {
+    case AST_FUNC_DEF:     rc = analyze_func_def(state, scope, n);     break;
+    case AST_PROC_DEF:     rc = analyze_proc_def(state, scope, n);     break;
+    case AST_VAR_DEF:      rc = analyze_var_def(state, scope, n);      break;
+    case AST_FORMAL_DEF:   rc = analyze_formal_def(state, scope, n);   break;
+    case AST_NEG:          rc = analyze_neg(state, scope, n);          break;
+    case AST_POS:          rc = analyze_pos(state, scope, n);          break;
+    case AST_CONST_DATE:   rc = analyze_const_date(state, scope, n);   break;
+    case AST_CONST_INT:    rc = analyze_const_int(state, scope, n);    break;
+    case AST_CONST_STRING: rc = analyze_const_string(state, scope, n); break;
+    case AST_EQ:           rc = analyze_eq(state, scope, n);           break;
+    case AST_LT:           rc = analyze_lt(state, scope, n);           break;
+    case AST_ADD:          rc = analyze_add(state, scope, n);          break;
+    case AST_SUB:          rc = analyze_sub(state, scope, n);          break;
+    case AST_MUL:          rc = analyze_mul(state, scope, n);          break;
+    case AST_DIV:          rc = analyze_div(state, scope, n);          break;
+    case AST_PRINT:        rc = analyze_print(state, scope, n);        break;
+    case AST_RETURN:       rc = analyze_return(state, scope, n);       break;
+    case AST_ID:           rc = analyze_id(state, scope, n);           break;
+    case AST_DOT:          rc = analyze_dot(state, scope, n);          break;
+    case AST_ATTR:         rc = analyze_attr(state, scope, n);         break;
+    case AST_ASSIGN:       rc = analyze_assign(state, scope, n);       break;
+    case AST_PROC_CALL:    rc = analyze_proc_call(state, scope, n);    break;
+    case AST_FUNC_CALL:    rc = analyze_func_call(state, scope, n);    break;
+    case AST_UNTIL:        rc = analyze_until(state, scope, n);        break;
+    case AST_UNLESS:       rc = analyze_unless(state, scope, n);       break;
+    case AST_UNLESS_EXPR:  rc = analyze_unless_expr(state, scope, n);  break;
+    default:
+        break;
+    }
+
+    /* every successfully analyzed node must have ended up with a type */
+    assert(rc != 0 || has_type(n));
+
+    ast_set_flags(n, AST_FLAG_CHECKED);
+    return rc;
+}
+
+/*
+ * Entry point of the checker. Runs two passes over the top-level nodes in
+ * root: first analyze_visibility() on each node, then analyze() on each
+ * node with a fresh, zeroed state. With DEBUG defined, every node is dumped
+ * afterwards.
+ *
+ * Returns 0 on success, -1 as soon as either pass reports an error.
+ */
+int check(struct scope *scope, struct ast *root)
+{
+    /* pass 1: make func/proc definitions known before anything is analyzed */
+    foreach_node(def, root) {
+        if (analyze_visibility(scope, def) != 0)
+            return -1;
+    }
+
+    /* pass 2: full analysis, each top-level node without a parent definition */
+    foreach_node(top, root) {
+        struct state fresh = {0};
+        if (analyze(&fresh, scope, top) != 0)
+            return -1;
+    }
+
+#ifdef DEBUG
+    foreach_node(dumped, root) {
+        printf("// after checking:\n");
+        ast_dump(0, dumped);
+    }
+#endif
+
+    return 0;
+}