diff options
Diffstat (limited to 'src/check.c')
-rw-r--r-- | src/check.c | 767 |
1 files changed, 767 insertions, 0 deletions
diff --git a/src/check.c b/src/check.c new file mode 100644 index 0000000..54929a1 --- /dev/null +++ b/src/check.c @@ -0,0 +1,767 @@ +#include <posthaste/check.h> + +#define UNUSED(x) (void)x + +struct state { + struct ast *parent; +}; + +int analyze_visibility(struct scope *scope, struct ast *n) +{ + switch (n->k) { + case AST_FUNC_DEF: + if (scope_add_func(scope, n)) + return -1; + + break; + + case AST_PROC_DEF: + if (scope_add_proc(scope, n)) + return -1; + + break; + default: + } + + return 0; +} + +static bool has_type(struct ast *n) +{ + return n->t != 0; +} + +static int analyze(struct state *state, struct scope *scope, struct ast *n); +static int analyze_list(struct state *state, struct scope *scope, struct ast *l) +{ + foreach_node(n, l) { + if (analyze(state, scope, n)) + return -1; + } + + return 0; +} + +static struct ast *file_scope_find_analyzed(struct state *state, struct scope *scope, char *id) +{ + struct ast *exists = file_scope_find(scope, id); + if (!exists) + return NULL; + + if (analyze(state, scope, exists)) + return NULL; + + return exists; +} + +static int analyze_type(struct scope *scope, struct ast *n) +{ + if (!n->type) { + n->t = TYPE_VOID; + return 0; + } + + if (same_id(n->type, "int")) { + n->t = TYPE_INT; + return 0; + } + else if (same_id(n->type, "date")) { + n->t = TYPE_DATE; + return 0; + } + + semantic_error(scope, n, "no such type: %s", n->type); + return -1; +} + +static int analyze_func_def(struct state *state, struct scope *scope, struct ast *f) +{ + UNUSED(state); + if (analyze_type(scope, f)) + return -1; + + /* slightly hacky, but allows recursion */ + ast_set_flags(f, AST_FLAG_CHECKED); + + struct scope *func_scope = create_scope(); + scope_add_scope(scope, func_scope); + f->scope = func_scope; + + struct state func_state = {.parent = f}; + if (analyze_list(&func_state, func_scope, func_formals(f))) + return -1; + + if (analyze_list(&func_state, func_scope, func_vars(f))) + return -1; + + return analyze(&func_state, func_scope, func_body(f)); +} + +static int analyze_proc_def(struct state *state, struct scope *scope, struct ast *p) +{ + UNUSED(state); + if (analyze_type(scope, p)) + return -1; + + /* slightly hacky, but allows recursion */ + ast_set_flags(p, AST_FLAG_CHECKED); + + struct scope *proc_scope = create_scope(); + scope_add_scope(scope, proc_scope); + p->scope = proc_scope; + + struct state proc_state = {.parent = p}; + if (analyze_list(&proc_state, proc_scope, proc_formals(p))) + return -1; + + if (analyze_list(&proc_state, proc_scope, proc_vars(p))) + return -1; + + if (analyze_list(&proc_state, proc_scope, proc_body(p))) + return -1; + + struct ast *last = ast_last(proc_body(p)); + if (p->t != TYPE_VOID && last->k != AST_RETURN) { + semantic_error(scope, p, "can't prove that proc returns"); + return -1; + } + + return 0; +} + +static int analyze_var_def(struct state *state, struct scope *scope, struct ast *n) +{ + if (scope_add_var(scope, n)) + return -1; + + struct ast *expr = var_expr(n); + if (analyze(state, scope, expr)) + return -1; + + n->t = expr->t; + return 0; +} + +static int analyze_formal_def(struct state *state, struct scope *scope, struct ast *n) +{ + UNUSED(state); + if (analyze_type(scope, n)) + return -1; + + if (scope_add_formal(scope, n)) + return -1; + + return 0; +} + +static int analyze_neg(struct state *state, struct scope *scope, struct ast *n) +{ + struct ast *expr = neg_expr(n); + if (analyze(state, scope, expr)) + return -1; + + if (expr->t != TYPE_INT) { + semantic_error(scope, n, + "negation requires %s, got %s\n", + type_str(TYPE_INT), + type_str(expr->t)); + return -1; + } + + n->t = expr->t; + return 0; +} + +static int analyze_pos(struct state *state, struct scope *scope, struct ast *p) +{ + struct ast *expr = pos_expr(p); + if (analyze(state, scope, expr)) + return -1; + + if (expr->t != TYPE_INT) { + semantic_error(scope, p, + "unary pos requires %s, got %s\n", + type_str(TYPE_INT), + type_str(expr->t)); + return -1; + } + + p->t = expr->t; + return 0; +} + +static int analyze_const_date(struct state *state, struct scope *scope, struct ast *d) +{ + UNUSED(state); + UNUSED(scope); + /* should be in correct format as the lexer took care of it */ + d->t = TYPE_DATE; + return 0; +} + +static int analyze_const_int(struct state *state, struct scope *scope, struct ast *i) +{ + UNUSED(state); + UNUSED(scope); + i->t = TYPE_INT; + return 0; +} + +static int analyze_const_string(struct state *state, struct scope *scope, struct ast *s) +{ + UNUSED(state); + UNUSED(scope); + s->t = TYPE_STRING; + return 0; +} + +static int analyze_eq(struct state *state, struct scope *scope, struct ast *e) +{ + struct ast *l = eq_l(e); + if (analyze(state, scope, l)) + return -1; + + struct ast *r = eq_r(e); + if (analyze(state, scope, r)) + return -1; + + if (r->t != l->t) { + semantic_error(scope, e, "type mismatch: %s vs %s\n", + type_str(l->t), + type_str(r->t)); + return -1; + } + + e->t = TYPE_BOOL; + return 0; +} + +static int analyze_lt(struct state *state, struct scope *scope, struct ast *e) +{ + struct ast *l = lt_l(e); + if (analyze(state, scope, l)) + return -1; + + struct ast *r = lt_r(e); + if (analyze(state, scope, r)) + return -1; + + if (r->t != l->t) { + semantic_error(scope, e, "type mismatch: %s vs %s\n", + type_str(l->t), + type_str(r->t)); + return -1; + } + + e->t = TYPE_BOOL; + return 0; +} + +static int analyze_add(struct state *state, struct scope *scope, struct ast *a) +{ + struct ast *l = add_l(a); + if (analyze(state, scope, l)) + return -1; + + struct ast *r = add_r(a); + if (analyze(state, scope, r)) + return -1; + + if (r->t == TYPE_DATE) { + semantic_error(scope, r, "date not allowed as rhs to addition"); + return -1; + } + + if (l->t == TYPE_DATE) { + a->k = AST_DATE_ADD; + } + + a->t = l->t; + return 0; +} + +static int analyze_sub(struct state *state, struct scope *scope, struct ast *s) +{ + struct ast *l = sub_l(s); + if (analyze(state, scope, l)) + return -1; + + struct ast *r = sub_r(s); + if (analyze(state, scope, r)) + return -1; + + if (l->t == TYPE_DATE && r->t == TYPE_DATE) { + s->k = AST_DATE_DIFF; + s->t = TYPE_INT; + return 0; + } + + if (l->t == TYPE_DATE && r->t == TYPE_INT) { + s->k = AST_DATE_SUB; + s->t = TYPE_DATE; + return 0; + } + + if (l->t == TYPE_INT && r->t == TYPE_INT) { + s->t = TYPE_INT; + return 0; + } + + semantic_error(scope, s, "illegal subtraction types: %s, %s", + type_str(l->t), type_str(r->t)); + return -1; +} + +static int analyze_mul(struct state *state, struct scope *scope, struct ast *m) +{ + struct ast *l = mul_l(m); + if (analyze(state, scope, l)) + return -1; + + if (l->t != TYPE_INT) { + semantic_error(scope, l, "expected %s, got %s", + type_str(TYPE_INT), type_str(l->t)); + return -1; + } + + struct ast *r = mul_r(m); + if (analyze(state, scope, r)) + return -1; + + if (r->t != TYPE_INT) { + semantic_error(scope, r, "expected %s, got %s", + type_str(TYPE_INT), type_str(r->t)); + return -1; + } + + m->t = TYPE_INT; + return 0; +} + +static int analyze_div(struct state *state, struct scope *scope, struct ast *d) +{ + struct ast *l = div_l(d); + if (analyze(state, scope, l)) + return -1; + + if (l->t != TYPE_INT) { + semantic_error(scope, l, "expected %s, got %s", + type_str(TYPE_INT), type_str(l->t)); + return -1; + } + + struct ast *r = div_r(d); + if (analyze(state, scope, r)) + return -1; + + if (r->t != TYPE_INT) { + semantic_error(scope, r, "expected %s, got %s", + type_str(TYPE_INT), type_str(r->t)); + return -1; + } + + d->t = TYPE_INT; + return 0; +} + +static int analyze_print(struct state *state, struct scope *scope, struct ast *p) +{ + struct ast *items = print_items(p); + struct ast *fixups = NULL, *last_fixup = NULL; + + foreach_node(n, items) { + if (analyze(state, scope, n)) + return -1; + + if (n->t == TYPE_VOID) { + semantic_error(scope, n, "void not printable"); + return -1; + } + + /* build new list of nodes, helps a bit when lowering */ + struct ast *fixup = fixup_print_item(n); + if (!fixups) + fixups = fixup; + + if (last_fixup) { + last_fixup->n = fixup; + } + + last_fixup = fixup; + } + + print_items(p) = fixups; + p->t = TYPE_VOID; + return 0; +} + +static int analyze_return(struct state *state, struct scope *scope, struct ast *r) +{ + struct ast *parent = state->parent; + if (!parent) { + semantic_error(scope, r, "stray return"); + return -1; + } + + + if (!has_type(parent)) { + semantic_error(scope, r, "stray return in proc without return type"); + return -1; + } + + struct ast *expr = return_expr(r); + if (analyze(state, scope, expr)) + return -1; + + if (expr->t != parent->t) { + semantic_error(scope, r, "return type mismatch: %s vs %s", + type_str(parent->t), type_str(expr->t)); + return -1; + } + + r->t = expr->t; + return 0; +} + +static int analyze_id(struct state *state, struct scope *scope, struct ast *i) +{ + UNUSED(state); + struct ast *exists = file_scope_find_analyzed(state, scope, i->id); + if (!exists) { + semantic_error(scope, i, "no such symbol"); + return -1; + } + + if (exists->k != AST_FORMAL_DEF && exists->k != AST_VAR_DEF) { + semantic_error(scope, i, "no such variable"); + return -1; + } + + i->t = exists->t; + return 0; +} + +static int analyze_dot(struct state *state, struct scope *scope, struct ast *d) +{ + struct ast *base = dot_base(d); + if (analyze(state, scope, base)) + return -1; + + if (base->t != TYPE_DATE) { + semantic_error(scope, d, "expected %d, got %d", + type_str(TYPE_DATE), type_str(base->t)); + return -1; + } + + d->t = TYPE_INT; + if (same_id(d->id, "day")) + return 0; + + if (same_id(d->id, "month")) + return 0; + + if (same_id(d->id, "year")) + return 0; + + semantic_error(scope, d, "illegal write attribute %s", d->id); + return -1; +} + +static int analyze_attr(struct state *state, struct scope *scope, struct ast *d) +{ + struct ast *base = attr_base(d); + if (analyze(state, scope, base)) + return -1; + + if (base->t != TYPE_DATE) { + semantic_error(scope, d, "expected %d, got %d", + type_str(TYPE_DATE), type_str(base->t)); + return -1; + } + + d->t = TYPE_INT; + if (same_id(d->id, "day")) + return 0; + + if (same_id(d->id, "month")) + return 0; + + if (same_id(d->id, "year")) + return 0; + + if (same_id(d->id, "weekday")) { + d->t = TYPE_STRING; + return 0; + } + + if (same_id(d->id, "weeknum")) + return 0; + + semantic_error(scope, d, "illegal write attribute %s", d->id); + return -1; +} + +static int analyze_assign(struct state *state, struct scope *scope, struct ast *n) +{ + struct ast *l = assign_l(n); + if (analyze(state, scope, l)) + return -1; + + struct ast *r = assign_r(n); + if (analyze(state, scope, r)) + return -1; + + if (l->t != r->t) { + semantic_error(scope, n, "type mismatch: %s vs %s", + type_str(l->t), type_str(r->t)); + return -1; + } + + n->t = l->t; + return 0; +} + +static int analyze_today(struct state *state, struct scope *scope, struct ast *c) +{ + UNUSED(state); + if (ast_list_len(func_call_args(c)) != 0) { + semantic_error(scope, c, "expected 0 arguments, got %zd", + ast_list_len(func_call_args(c))); + return -1; + } + + c->k = AST_BUILTIN_CALL; + c->t = TYPE_DATE; + return 0; +} + +static int analyze_proc_call(struct state *state, struct scope *scope, struct ast *c) +{ + struct ast *exists = file_scope_find_analyzed(state, scope, proc_call_id(c)); + if (!exists || exists->k != AST_PROC_DEF) { + semantic_error(scope, c, "no such proc"); + return -1; + } + + struct ast *args = proc_call_args(c); + if (analyze_list(state, scope, args)) + return -1; + + struct ast *formal = proc_formals(exists); + if (ast_list_len(formal) != ast_list_len(args)) { + semantic_error(scope, c, "expected %s args, got %s", + ast_list_len(formal), ast_list_len(args)); + return -1; + } + + foreach_node(a, args) { + if (a->t != formal->t) { + semantic_error(scope, c, "expected %s, got %s", + type_str(formal->t), type_str(a->t)); + } + + formal = formal->n; + } + + c->t = exists->t; + return 0; +} + +static int analyze_func_call(struct state *state, struct scope *scope, struct ast *c) +{ + /* handle special Today() built-in */ + if (same_id(func_call_id(c), "Today")) + return analyze_today(state, scope, c); + + struct ast *exists = file_scope_find_analyzed(state, scope, func_call_id(c)); + if (!exists || exists->k != AST_FUNC_DEF) { + semantic_error(scope, c, "no such func"); + return -1; + } + + struct ast *args = func_call_args(c); + if (analyze_list(state, scope, args)) + return -1; + + struct ast *formal = func_formals(exists); + if (ast_list_len(formal) != ast_list_len(args)) { + semantic_error(scope, c, "expected %s args, got %s", + ast_list_len(formal), ast_list_len(args)); + return -1; + } + + foreach_node(a, args) { + if (a->t != formal->t) { + semantic_error(scope, c, "expected %s, got %s", + type_str(formal->t), type_str(a->t)); + } + + formal = formal->n; + } + + c->t = exists->t; + return 0; +} + +static int analyze_until(struct state *state, struct scope *scope, struct ast *n) +{ + struct scope *until_scope = create_scope(); + scope_add_scope(scope, until_scope); + + struct ast *body = until_body(n); + if (analyze_list(state, until_scope, body)) + return -1; + + struct ast *cond = until_cond(n); + if (analyze(state, until_scope, cond)) + return -1; + + if (cond->t != TYPE_BOOL) { + semantic_error(scope, cond, "expected %s, got %s", + type_str(TYPE_BOOL), type_str(cond->t)); + return -1; + } + + n->t = TYPE_VOID; + return 0; +} + +static int analyze_unless(struct state *state, struct scope *scope, struct ast *n) +{ + struct scope *unless_scope = create_scope(); + struct scope *otherwise_scope = create_scope(); + scope_add_scope(scope, unless_scope); + scope_add_scope(scope, otherwise_scope); + + struct ast *body = unless_body(n); + if (analyze_list(state, unless_scope, body)) + return -1; + + struct ast *cond = unless_cond(n); + if (analyze(state, scope, cond)) + return -1; + + if (cond->t != TYPE_BOOL) { + semantic_error(scope, cond, "expected %s, got %s", + type_str(TYPE_BOOL), type_str(cond->t)); + return -1; + } + + struct ast *otherwise = unless_otherwise(n); + if (otherwise && analyze_list(state, otherwise_scope, otherwise)) + return -1; + + n->t = TYPE_VOID; + return 0; +} + +static int analyze_unless_expr(struct state *state, struct scope *scope, struct ast *n) +{ + struct ast *body = unless_expr_body(n); + if (analyze(state, scope, body)) + return -1; + + struct ast *cond = unless_expr_cond(n); + if (analyze(state, scope, cond)) + return -1; + + if (cond->t != TYPE_BOOL) { + semantic_error(scope, cond, "expected %s, got %s", + type_str(TYPE_BOOL), type_str(cond->t)); + return -1; + } + + struct ast *otherwise = unless_expr_otherwise(n); + if (analyze(state, scope, otherwise)) + return -1; + + if (body->t != otherwise->t) { + semantic_error(scope, n, "type mismatch: %s vs %s", + type_str(body->t), type_str(otherwise->t)); + return -1; + } + + n->t = body->t; + return 0; +} + +static int analyze(struct state *state, struct scope *scope, struct ast *n) +{ + int ret = 0; + + if (ast_flags(n, AST_FLAG_CHECKED)) { + assert(has_type(n)); + return 0; + } + + if (ast_flags(n, AST_FLAG_INIT)) { + semantic_error(scope, n, "semantic loop detected"); + return -1; + } + + ast_set_flags(n, AST_FLAG_INIT); + + /* generally what we want to happen, special cases can handle themselves */ + if (!n->scope) + n->scope = scope; + + switch (n->k) { + case AST_FUNC_DEF: ret = analyze_func_def(state, scope, n); break; + case AST_PROC_DEF: ret = analyze_proc_def(state, scope, n); break; + case AST_VAR_DEF: ret = analyze_var_def(state, scope, n); break; + case AST_FORMAL_DEF: ret = analyze_formal_def(state, scope, n); break; + case AST_NEG: ret = analyze_neg(state, scope, n); break; + case AST_POS: ret = analyze_pos(state, scope, n); break; + case AST_CONST_DATE: ret = analyze_const_date(state, scope, n); break; + case AST_CONST_INT: ret = analyze_const_int(state, scope, n); break; + case AST_CONST_STRING: ret = analyze_const_string(state, scope, n); break; + case AST_EQ: ret = analyze_eq(state, scope, n); break; + case AST_LT: ret = analyze_lt(state, scope, n); break; + case AST_ADD: ret = analyze_add(state, scope, n); break; + case AST_SUB: ret = analyze_sub(state, scope, n); break; + case AST_MUL: ret = analyze_mul(state, scope, n); break; + case AST_DIV: ret = analyze_div(state, scope, n); break; + case AST_PRINT: ret = analyze_print(state, scope, n); break; + case AST_RETURN: ret = analyze_return(state, scope, n); break; + case AST_ID: ret = analyze_id(state, scope, n); break; + case AST_DOT: ret = analyze_dot(state, scope, n); break; + case AST_ATTR: ret = analyze_attr(state, scope, n); break; + case AST_ASSIGN: ret = analyze_assign(state, scope, n); break; + case AST_PROC_CALL: ret = analyze_proc_call(state, scope, n); break; + case AST_FUNC_CALL: ret = analyze_func_call(state, scope, n); break; + case AST_UNTIL: ret = analyze_until(state, scope, n); break; + case AST_UNLESS: ret = analyze_unless(state, scope, n); break; + case AST_UNLESS_EXPR: ret = analyze_unless_expr(state, scope, n); break; + default: break; + } + + if (ret == 0) { + assert(has_type(n)); + } + + ast_set_flags(n, AST_FLAG_CHECKED); + return ret; +} + +int check(struct scope *scope, struct ast *root) +{ + foreach_node(n, root) { + if (analyze_visibility(scope, n)) + return -1; + } + + foreach_node(n, root) { + struct state state = {0}; + if (analyze(&state, scope, n)) + return -1; + } + +#ifdef DEBUG + foreach_node(n, root) { + printf("// after checking:\n"); + ast_dump(0, n); + } +#endif + + return 0; +} |