#include #include #define UNUSED(x) (void)x struct state { struct ast *parent; }; int analyze_visibility(struct scope *scope, struct ast *n) { switch (n->k) { case AST_FUNC_DEF: if (scope_add_func(scope, n)) return -1; break; case AST_PROC_DEF: if (scope_add_proc(scope, n)) return -1; break; default: } return 0; } static bool has_type(struct ast *n) { return n->t != 0; } static int analyze(struct state *state, struct scope *scope, struct ast *n); static int analyze_list(struct state *state, struct scope *scope, struct ast *l) { foreach_node(n, l) { if (analyze(state, scope, n)) return -1; } return 0; } static struct ast *file_scope_find_analyzed(struct state *state, struct scope *scope, char *id) { struct ast *exists = file_scope_find(scope, id); if (!exists) return NULL; if (analyze(state, scope, exists)) return NULL; return exists; } static int analyze_type(struct scope *scope, struct ast *n) { if (!n->type) { n->t = TYPE_VOID; return 0; } if (same_id(n->type, "int")) { n->t = TYPE_INT; return 0; } else if (same_id(n->type, "date")) { n->t = TYPE_DATE; return 0; } else if (same_id(n->type, "auto")) { n->t = TYPE_AUTO; return 0; } semantic_error(scope, n, "no such type: %s", n->type); return -1; } static int analyze_func_def(struct state *state, struct scope *scope, struct ast *f) { UNUSED(state); if (analyze_type(scope, f)) return -1; /* slightly hacky, but allows recursion */ ast_set_flags(f, AST_FLAG_CHECKED); struct scope *func_scope = create_scope(); scope_add_scope(scope, func_scope); f->scope = func_scope; struct state func_state = {.parent = f}; if (analyze_list(&func_state, func_scope, func_formals(f))) return -1; if (analyze_list(&func_state, func_scope, func_vars(f))) return -1; if (analyze(&func_state, func_scope, func_body(f))) return -1; if (f->t == TYPE_AUTO) { semantic_error(scope, f, "unable to determine return type"); return -1; } return 0; } static int analyze_proc_def(struct state *state, struct scope *scope, struct ast *p) { UNUSED(state); if (analyze_type(scope, p)) return -1; /* slightly hacky, but allows recursion */ ast_set_flags(p, AST_FLAG_CHECKED); struct scope *proc_scope = create_scope(); scope_add_scope(scope, proc_scope); p->scope = proc_scope; struct state proc_state = {.parent = p}; if (analyze_list(&proc_state, proc_scope, proc_formals(p))) return -1; if (analyze_list(&proc_state, proc_scope, proc_vars(p))) return -1; if (analyze_list(&proc_state, proc_scope, proc_body(p))) return -1; struct ast *last = ast_last(proc_body(p)); if (p->t != TYPE_VOID && last->k != AST_RETURN) { semantic_error(scope, p, "can't prove that proc returns"); return -1; } if (p->t == TYPE_AUTO) { semantic_error(scope, p, "unable to determine return type"); return -1; } return 0; } static int analyze_var_def(struct state *state, struct scope *scope, struct ast *n) { if (scope_add_var(scope, n)) return -1; struct ast *expr = var_expr(n); if (analyze(state, scope, expr)) return -1; n->t = expr->t; return 0; } static int analyze_formal_def(struct state *state, struct scope *scope, struct ast *n) { UNUSED(state); if (analyze_type(scope, n)) return -1; if (scope_add_formal(scope, n)) return -1; return 0; } static int analyze_neg(struct state *state, struct scope *scope, struct ast *n) { struct ast *expr = neg_expr(n); if (analyze(state, scope, expr)) return -1; if (expr->t != TYPE_INT) { semantic_error(scope, n, "negation requires %s, got %s\n", type_str(TYPE_INT), type_str(expr->t)); return -1; } n->t = expr->t; return 0; } static int analyze_pos(struct state *state, struct scope *scope, struct ast *p) { struct ast *expr = pos_expr(p); if (analyze(state, scope, expr)) return -1; if (expr->t != TYPE_INT) { semantic_error(scope, p, "unary pos requires %s, got %s\n", type_str(TYPE_INT), type_str(expr->t)); return -1; } p->t = expr->t; return 0; } static int analyze_const_date(struct state *state, struct scope *scope, struct ast *d) { UNUSED(state); UNUSED(scope); /* should be in correct format as the lexer took care of it */ d->t = TYPE_DATE; return 0; } static int analyze_const_int(struct state *state, struct scope *scope, struct ast *i) { UNUSED(state); UNUSED(scope); i->t = TYPE_INT; return 0; } static int analyze_const_string(struct state *state, struct scope *scope, struct ast *s) { UNUSED(state); UNUSED(scope); s->t = TYPE_STRING; return 0; } static int analyze_eq(struct state *state, struct scope *scope, struct ast *e) { struct ast *l = eq_l(e); if (analyze(state, scope, l)) return -1; struct ast *r = eq_r(e); if (analyze(state, scope, r)) return -1; if (r->t != l->t) { semantic_error(scope, e, "type mismatch: %s vs %s\n", type_str(l->t), type_str(r->t)); return -1; } e->t = TYPE_BOOL; return 0; } static int analyze_lt(struct state *state, struct scope *scope, struct ast *e) { struct ast *l = lt_l(e); if (analyze(state, scope, l)) return -1; struct ast *r = lt_r(e); if (analyze(state, scope, r)) return -1; if (r->t != l->t) { semantic_error(scope, e, "type mismatch: %s vs %s\n", type_str(l->t), type_str(r->t)); return -1; } e->t = TYPE_BOOL; return 0; } static int analyze_add(struct state *state, struct scope *scope, struct ast *a) { struct ast *l = add_l(a); if (analyze(state, scope, l)) return -1; struct ast *r = add_r(a); if (analyze(state, scope, r)) return -1; if (r->t == TYPE_DATE) { semantic_error(scope, r, "date not allowed as rhs to addition"); return -1; } if (l->t == TYPE_DATE) { a->k = AST_DATE_ADD; } a->t = l->t; return 0; } static int analyze_sub(struct state *state, struct scope *scope, struct ast *s) { struct ast *l = sub_l(s); if (analyze(state, scope, l)) return -1; struct ast *r = sub_r(s); if (analyze(state, scope, r)) return -1; if (l->t == TYPE_DATE && r->t == TYPE_DATE) { s->k = AST_DATE_DIFF; s->t = TYPE_INT; return 0; } if (l->t == TYPE_DATE && r->t == TYPE_INT) { s->k = AST_DATE_SUB; s->t = TYPE_DATE; return 0; } if (l->t == TYPE_INT && r->t == TYPE_INT) { s->t = TYPE_INT; return 0; } semantic_error(scope, s, "illegal subtraction types: %s, %s", type_str(l->t), type_str(r->t)); return -1; } static int analyze_mul(struct state *state, struct scope *scope, struct ast *m) { struct ast *l = mul_l(m); if (analyze(state, scope, l)) return -1; if (l->t != TYPE_INT) { semantic_error(scope, l, "expected %s, got %s", type_str(TYPE_INT), type_str(l->t)); return -1; } struct ast *r = mul_r(m); if (analyze(state, scope, r)) return -1; if (r->t != TYPE_INT) { semantic_error(scope, r, "expected %s, got %s", type_str(TYPE_INT), type_str(r->t)); return -1; } m->t = TYPE_INT; return 0; } static int analyze_div(struct state *state, struct scope *scope, struct ast *d) { struct ast *l = div_l(d); if (analyze(state, scope, l)) return -1; if (l->t != TYPE_INT) { semantic_error(scope, l, "expected %s, got %s", type_str(TYPE_INT), type_str(l->t)); return -1; } struct ast *r = div_r(d); if (analyze(state, scope, r)) return -1; if (r->t != TYPE_INT) { semantic_error(scope, r, "expected %s, got %s", type_str(TYPE_INT), type_str(r->t)); return -1; } d->t = TYPE_INT; return 0; } static int analyze_print(struct state *state, struct scope *scope, struct ast *p) { struct ast *items = print_items(p); struct ast *fixups = NULL, *last_fixup = NULL; foreach_node(n, items) { if (analyze(state, scope, n)) return -1; if (n->t == TYPE_VOID) { semantic_error(scope, n, "void not printable"); return -1; } /* build new list of nodes, helps a bit when lowering */ struct ast *fixup = fixup_print_item(n); if (!fixups) fixups = fixup; if (last_fixup) { last_fixup->n = fixup; } last_fixup = fixup; } print_items(p) = fixups; p->t = TYPE_VOID; return 0; } static int analyze_return(struct state *state, struct scope *scope, struct ast *r) { struct ast *parent = state->parent; if (!parent) { semantic_error(scope, r, "stray return"); return -1; } if (!has_type(parent)) { semantic_error(scope, r, "stray return in proc without return type"); return -1; } struct ast *expr = return_expr(r); if (analyze(state, scope, expr)) return -1; /* deduce return type from first return we hit */ if (parent->t == TYPE_AUTO) parent->t = expr->t; if (expr->t != parent->t) { semantic_error(scope, r, "return type mismatch: %s vs %s", type_str(parent->t), type_str(expr->t)); return -1; } r->t = expr->t; return 0; } static int analyze_id(struct state *state, struct scope *scope, struct ast *i) { UNUSED(state); struct ast *exists = file_scope_find_analyzed(state, scope, i->id); if (!exists) { semantic_error(scope, i, "no such symbol"); return -1; } if (exists->k != AST_FORMAL_DEF && exists->k != AST_VAR_DEF) { semantic_error(scope, i, "no such variable"); return -1; } i->t = exists->t; return 0; } static int analyze_dot(struct state *state, struct scope *scope, struct ast *d) { struct ast *base = dot_base(d); if (analyze(state, scope, base)) return -1; if (base->t != TYPE_DATE) { semantic_error(scope, d, "expected %s, got %s", type_str(TYPE_DATE), type_str(base->t)); return -1; } d->t = TYPE_INT; if (same_id(d->id, "day")) return 0; if (same_id(d->id, "month")) return 0; if (same_id(d->id, "year")) return 0; semantic_error(scope, d, "illegal write attribute %s", d->id); return -1; } static int analyze_attr(struct state *state, struct scope *scope, struct ast *d) { struct ast *base = attr_base(d); if (analyze(state, scope, base)) return -1; if (base->t != TYPE_DATE) { semantic_error(scope, d, "expected %s, got %s", type_str(TYPE_DATE), type_str(base->t)); return -1; } d->t = TYPE_INT; if (same_id(d->id, "day")) return 0; if (same_id(d->id, "month")) return 0; if (same_id(d->id, "year")) return 0; if (same_id(d->id, "weekday")) { d->t = TYPE_STRING; return 0; } if (same_id(d->id, "weeknum")) return 0; semantic_error(scope, d, "illegal write attribute %s", d->id); return -1; } static int analyze_assign(struct state *state, struct scope *scope, struct ast *n) { struct ast *l = assign_l(n); if (analyze(state, scope, l)) return -1; struct ast *r = assign_r(n); if (analyze(state, scope, r)) return -1; if (l->t != r->t) { semantic_error(scope, n, "type mismatch: %s vs %s", type_str(l->t), type_str(r->t)); return -1; } n->t = l->t; return 0; } static int analyze_today(struct state *state, struct scope *scope, struct ast *c) { UNUSED(state); if (ast_list_len(func_call_args(c)) != 0) { semantic_error(scope, c, "expected 0 arguments, got %zd", ast_list_len(func_call_args(c))); return -1; } c->k = AST_BUILTIN_CALL; c->t = TYPE_DATE; return 0; } static int analyze_proc_call(struct state *state, struct scope *scope, struct ast *c) { struct ast *exists = file_scope_find_analyzed(state, scope, proc_call_id(c)); if (!exists || exists->k != AST_PROC_DEF) { semantic_error(scope, c, "no such proc"); return -1; } if (exists->t == TYPE_AUTO) { semantic_error(scope, c, "proc call before deduction of auto"); return -1; } struct ast *args = proc_call_args(c); if (analyze_list(state, scope, args)) return -1; struct ast *formal = proc_formals(exists); if (ast_list_len(formal) != ast_list_len(args)) { semantic_error(scope, c, "expected %zd args, got %zd", ast_list_len(formal), ast_list_len(args)); return -1; } foreach_node(a, args) { if (a->t != formal->t) { semantic_error(scope, c, "expected %s, got %s", type_str(formal->t), type_str(a->t)); return -1; } formal = formal->n; } c->t = exists->t; return 0; } static int analyze_func_call(struct state *state, struct scope *scope, struct ast *c) { /* handle special Today() built-in */ if (same_id(func_call_id(c), "Today")) return analyze_today(state, scope, c); struct ast *exists = file_scope_find_analyzed(state, scope, func_call_id(c)); if (!exists || exists->k != AST_FUNC_DEF) { semantic_error(scope, c, "no such func"); return -1; } if (exists->t == TYPE_AUTO) { semantic_error(scope, c, "func call before deduction of auto"); return -1; } struct ast *args = func_call_args(c); if (analyze_list(state, scope, args)) return -1; struct ast *formal = func_formals(exists); if (ast_list_len(formal) != ast_list_len(args)) { semantic_error(scope, c, "expected %zd args, got %zd", ast_list_len(formal), ast_list_len(args)); return -1; } foreach_node(a, args) { if (a->t != formal->t) { semantic_error(scope, c, "expected %s, got %s", type_str(formal->t), type_str(a->t)); return -1; } formal = formal->n; } c->t = exists->t; return 0; } static int analyze_until(struct state *state, struct scope *scope, struct ast *n) { struct scope *until_scope = create_scope(); scope_add_scope(scope, until_scope); struct ast *body = until_body(n); if (analyze_list(state, until_scope, body)) return -1; struct ast *cond = until_cond(n); if (analyze(state, until_scope, cond)) return -1; if (cond->t != TYPE_BOOL) { semantic_error(scope, cond, "expected %s, got %s", type_str(TYPE_BOOL), type_str(cond->t)); return -1; } n->t = TYPE_VOID; return 0; } static int analyze_unless(struct state *state, struct scope *scope, struct ast *n) { struct scope *unless_scope = create_scope(); struct scope *otherwise_scope = create_scope(); scope_add_scope(scope, unless_scope); scope_add_scope(scope, otherwise_scope); struct ast *body = unless_body(n); if (analyze_list(state, unless_scope, body)) return -1; struct ast *cond = unless_cond(n); if (analyze(state, scope, cond)) return -1; if (cond->t != TYPE_BOOL) { semantic_error(scope, cond, "expected %s, got %s", type_str(TYPE_BOOL), type_str(cond->t)); return -1; } struct ast *otherwise = unless_otherwise(n); if (otherwise && analyze_list(state, otherwise_scope, otherwise)) return -1; n->t = TYPE_VOID; return 0; } static int analyze_unless_expr(struct state *state, struct scope *scope, struct ast *n) { struct ast *body = unless_expr_body(n); if (analyze(state, scope, body)) return -1; struct ast *cond = unless_expr_cond(n); if (analyze(state, scope, cond)) return -1; if (cond->t != TYPE_BOOL) { semantic_error(scope, cond, "expected %s, got %s", type_str(TYPE_BOOL), type_str(cond->t)); return -1; } struct ast *otherwise = unless_expr_otherwise(n); if (analyze(state, scope, otherwise)) return -1; if (body->t != otherwise->t) { semantic_error(scope, n, "type mismatch: %s vs %s", type_str(body->t), type_str(otherwise->t)); return -1; } n->t = body->t; return 0; } static int analyze(struct state *state, struct scope *scope, struct ast *n) { int ret = 0; if (ast_flags(n, AST_FLAG_CHECKED)) { assert(has_type(n)); return 0; } if (ast_flags(n, AST_FLAG_INIT)) { semantic_error(scope, n, "semantic loop detected"); return -1; } ast_set_flags(n, AST_FLAG_INIT); /* generally what we want to happen, special cases can handle themselves */ if (!n->scope) n->scope = scope; switch (n->k) { case AST_FUNC_DEF: ret = analyze_func_def(state, scope, n); break; case AST_PROC_DEF: ret = analyze_proc_def(state, scope, n); break; case AST_VAR_DEF: ret = analyze_var_def(state, scope, n); break; case AST_FORMAL_DEF: ret = analyze_formal_def(state, scope, n); break; case AST_NEG: ret = analyze_neg(state, scope, n); break; case AST_POS: ret = analyze_pos(state, scope, n); break; case AST_CONST_DATE: ret = analyze_const_date(state, scope, n); break; case AST_CONST_INT: ret = analyze_const_int(state, scope, n); break; case AST_CONST_STRING: ret = analyze_const_string(state, scope, n); break; case AST_EQ: ret = analyze_eq(state, scope, n); break; case AST_LT: ret = analyze_lt(state, scope, n); break; case AST_ADD: ret = analyze_add(state, scope, n); break; case AST_SUB: ret = analyze_sub(state, scope, n); break; case AST_MUL: ret = analyze_mul(state, scope, n); break; case AST_DIV: ret = analyze_div(state, scope, n); break; case AST_PRINT: ret = analyze_print(state, scope, n); break; case AST_RETURN: ret = analyze_return(state, scope, n); break; case AST_ID: ret = analyze_id(state, scope, n); break; case AST_DOT: ret = analyze_dot(state, scope, n); break; case AST_ATTR: ret = analyze_attr(state, scope, n); break; case AST_ASSIGN: ret = analyze_assign(state, scope, n); break; case AST_PROC_CALL: ret = analyze_proc_call(state, scope, n); break; case AST_FUNC_CALL: ret = analyze_func_call(state, scope, n); break; case AST_UNTIL: ret = analyze_until(state, scope, n); break; case AST_UNLESS: ret = analyze_unless(state, scope, n); break; case AST_UNLESS_EXPR: ret = analyze_unless_expr(state, scope, n); break; default: break; } if (ret == 0) { assert(has_type(n)); } ast_set_flags(n, AST_FLAG_CHECKED); return ret; } int check(struct scope *scope, struct ast *root) { foreach_node(n, root) { if (analyze_visibility(scope, n)) return -1; } foreach_node(n, root) { struct state state = {0}; if (analyze(&state, scope, n)) return -1; } #ifdef DEBUG foreach_node(n, root) { printf("// after checking:\n"); ast_dump(0, n); } #endif return 0; }