implement basics of type checking

This commit is contained in:
Josh Wolfe 2015-11-30 18:43:45 -07:00
parent ef482ece7c
commit abbc395701

View File

@ -10,6 +10,12 @@
#include "error.hpp"
#include "zig_llvm.hpp"
// Scope information that is passed down while analyzing the expressions of a
// function body. The only context constructed in the visible code is the
// function-level one: its `node` is the function definition node, its `root`
// points back at itself and its `parent` is nullptr.
struct BlockContext {
// AST node this context was created for. For the root context this is the
// NodeTypeFnDef node (get_return_type asserts this on root->node).
AstNode *node;
// Outermost context; get_return_type reaches the function definition
// through root->node.
BlockContext *root;
// Presumably the enclosing context for nested scopes (nested block scopes
// are still a TODO in analyze_expression); nullptr for the root context.
BlockContext *parent;
};
static void add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
g->errors.add_one();
ErrorMsg *last_msg = &g->errors.last();
@ -229,6 +235,155 @@ static void preview_function_declarations(CodeGen *g, AstNode *node) {
}
}
// Looks up the declared return type of the function that `context` is nested in.
//
// The function definition is taken from the root context's node; from there the
// prototype's return type node is followed to the type table entry stored in
// its codegen data. The asserts document the shape of the AST this relies on:
// root node is a function definition, its prototype is a function prototype,
// and the return type node already has codegen data attached.
static TypeTableEntry * get_return_type(BlockContext *context) {
    BlockContext *outermost = context->root;

    AstNode *definition = outermost->node;
    assert(definition->type == NodeTypeFnDef);

    AstNode *prototype = definition->data.fn_def.fn_proto;
    assert(prototype->type == NodeTypeFnProto);

    AstNode *declared_return = prototype->data.fn_proto.return_type;
    assert(declared_return->codegen_node);

    return declared_return->codegen_node->data.type_node.entry;
}
// Reports a "type mismatch." error on `node` when a value of `actual_type`
// appears where `expected_type` is required.
//
// Parameters:
//   g             - code generator state; receives the error and provides the
//                   builtin type entries used for the special cases below.
//   node          - AST node the error is attached to.
//   expected_type - type required by the surrounding code, or nullptr when the
//                   caller has no expectation (analyze_expression passes nullptr
//                   for block statements and for arguments whose parameter type
//                   is unknown).
//   actual_type   - type the expression was found to have.
//
// No error is reported when:
//   * there is no expectation (expected_type is nullptr),
//   * both entries are the same,
//   * either side is the invalid type (an error was already reported for it),
//   * the actual type is unreachable.
static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry *expected_type, TypeTableEntry *actual_type) {
    // Without this guard every call analyzed with "no expectation" would be
    // compared against nullptr and flagged as a mismatch.
    if (expected_type == nullptr) {
        return; // anything will do
    }
    if (expected_type == actual_type) {
        return; // good
    }
    if (expected_type == g->builtin_types.entry_invalid || actual_type == g->builtin_types.entry_invalid) {
        return; // already complained
    }
    if (actual_type == g->builtin_types.entry_unreachable) {
        return; // TODO: is this true?
    }

    // TODO better error message
    add_node_error(g, node, buf_sprintf("type mismatch."));
}
// Type-checks the expression rooted at `node` and returns its type.
//
// Parameters:
//   g             - code generator state (error list, function table, builtin types).
//   context       - scope the expression lives in; used to find the enclosing
//                   function's return type for `return` expressions.
//   expected_type - type the surrounding code wants, or nullptr when there is
//                   no expectation.
//   node          - expression node to analyze.
//
// Returns the type table entry of the expression. Errors are recorded through
// add_node_error; the invalid type is returned for calls to unknown functions
// and the unreachable type for `return` and `unreachable` expressions.
static TypeTableEntry * analyze_expression(CodeGen *g, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) {
    switch (node->type) {
        case NodeTypeBlock:
            {
                // TODO: nested block scopes
                // The type of a block is the type of its last analyzed statement;
                // an empty block is void.
                TypeTableEntry *last_type = g->builtin_types.entry_void;
                int index = 0;
                while (index < node->data.block.statements.length) {
                    AstNode *statement = node->data.block.statements.at(index);
                    if (last_type == g->builtin_types.entry_unreachable) {
                        // The previous statement never completes, so this one
                        // can never run. Report once and stop analyzing.
                        add_node_error(g, statement,
                                buf_sprintf("unreachable code"));
                        return last_type;
                    }
                    last_type = analyze_expression(g, context, nullptr, statement);
                    index += 1;
                }
                return last_type;
            }

        case NodeTypeReturnExpr:
            {
                TypeTableEntry *declared_type = get_return_type(context);
                AstNode *value_node = node->data.return_expr.expr;

                // A bare `return` yields void.
                TypeTableEntry *value_type = value_node
                    ? analyze_expression(g, context, declared_type, value_node)
                    : g->builtin_types.entry_void;

                if (value_type == g->builtin_types.entry_unreachable) {
                    // "return exit(0)" should just be "exit(0)".
                    add_node_error(g, node, buf_sprintf("returning is unreachable."));
                    // Switch to the invalid type so the compatibility check
                    // below does not run its unreachable special case.
                    value_type = g->builtin_types.entry_invalid;
                }

                check_type_compatibility(g, node, declared_type, value_type);
                return g->builtin_types.entry_unreachable;
            }

        case NodeTypeBinOpExpr:
            {
                // TODO: think about expected types
                // Both operands are analyzed against the caller's expectation and
                // the expectation itself is handed back as the result type.
                analyze_expression(g, context, expected_type, node->data.bin_op_expr.op1);
                analyze_expression(g, context, expected_type, node->data.bin_op_expr.op2);
                return expected_type;
            }

        case NodeTypeFnCallExpr:
            {
                Buf *fn_name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
                auto lookup = g->fn_table.maybe_get(fn_name);

                if (!lookup) {
                    add_node_error(g, node,
                            buf_sprintf("undefined function: '%s'", buf_ptr(fn_name)));
                    // still analyze the parameters, even though we don't know what to expect
                    for (int arg_index = 0; arg_index < node->data.fn_call_expr.params.length; arg_index += 1) {
                        AstNode *argument = node->data.fn_call_expr.params.at(arg_index);
                        analyze_expression(g, context, nullptr, argument);
                    }
                    return g->builtin_types.entry_invalid;
                }

                FnTableEntry *callee = lookup->value;
                assert(callee->proto_node->type == NodeTypeFnProto);
                AstNodeFnProto *callee_proto = &callee->proto_node->data.fn_proto;

                // count parameters
                int declared_count = callee_proto->params.length;
                int passed_count = node->data.fn_call_expr.params.length;
                if (declared_count != passed_count) {
                    add_node_error(g, node,
                            buf_sprintf("wrong number of arguments. Expected %d, got %d.",
                                declared_count, passed_count));
                }

                // analyze each parameter
                for (int arg_index = 0; arg_index < node->data.fn_call_expr.params.length; arg_index += 1) {
                    AstNode *argument = node->data.fn_call_expr.params.at(arg_index);

                    // determine the expected type for each parameter; surplus
                    // arguments and parameters without codegen data get no expectation
                    TypeTableEntry *param_type = nullptr;
                    if (arg_index < callee_proto->params.length) {
                        AstNode *param_decl = callee_proto->params.at(arg_index);
                        assert(param_decl->type == NodeTypeParamDecl);
                        AstNode *param_type_node = param_decl->data.param_decl.type;
                        if (param_type_node->codegen_node) {
                            param_type = param_type_node->codegen_node->data.type_node.entry;
                        }
                    }

                    analyze_expression(g, context, param_type, argument);
                }

                TypeTableEntry *result_type = callee_proto->return_type->codegen_node->data.type_node.entry;
                check_type_compatibility(g, node, expected_type, result_type);
                return result_type;
            }

        case NodeTypeNumberLiteral:
            // TODO: generic literal int type
            return g->builtin_types.entry_i32;

        case NodeTypeStringLiteral:
            zig_panic("TODO");

        case NodeTypeUnreachable:
            return g->builtin_types.entry_unreachable;

        case NodeTypeSymbol:
            // look up symbol in symbol table
            zig_panic("TODO");

        case NodeTypeCastExpr:
        case NodeTypePrefixOpExpr:
            zig_panic("TODO");

        // None of the remaining node kinds is an expression.
        case NodeTypeDirective:
        case NodeTypeFnDecl:
        case NodeTypeFnProto:
        case NodeTypeParamDecl:
        case NodeTypeType:
        case NodeTypeRoot:
        case NodeTypeRootExportDecl:
        case NodeTypeExternBlock:
        case NodeTypeFnDef:
            zig_unreachable();
    }
    zig_unreachable();
}
static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
// Follow the execution flow and make sure the code returns appropriately.
// * A `return` statement in an unreachable type function should be an error.
@ -282,74 +437,6 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
}
}
// Walks the expression rooted at `node` and records errors for calls to
// functions that are not in the function table and for calls whose argument
// count differs from the callee's prototype. Sub-expressions of blocks,
// returns, binary operators and call arguments are visited recursively.
// Nothing is returned; problems are reported through add_node_error.
static void analyze_expression(CodeGen *g, AstNode *node) {
    switch (node->type) {
        case NodeTypeBlock:
            {
                int index = 0;
                while (index < node->data.block.statements.length) {
                    AstNode *statement = node->data.block.statements.at(index);
                    analyze_expression(g, statement);
                    index += 1;
                }
                return;
            }

        case NodeTypeReturnExpr:
            {
                AstNode *value_node = node->data.return_expr.expr;
                if (value_node) {
                    analyze_expression(g, value_node);
                }
                return;
            }

        case NodeTypeBinOpExpr:
            analyze_expression(g, node->data.bin_op_expr.op1);
            analyze_expression(g, node->data.bin_op_expr.op2);
            return;

        case NodeTypeFnCallExpr:
            {
                Buf *fn_name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
                auto lookup = g->fn_table.maybe_get(fn_name);
                if (lookup) {
                    FnTableEntry *callee = lookup->value;
                    assert(callee->proto_node->type == NodeTypeFnProto);
                    int declared_count = callee->proto_node->data.fn_proto.params.length;
                    int passed_count = node->data.fn_call_expr.params.length;
                    if (declared_count != passed_count) {
                        add_node_error(g, node,
                                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
                                    declared_count, passed_count));
                    }
                } else {
                    add_node_error(g, node,
                            buf_sprintf("undefined function: '%s'", buf_ptr(fn_name)));
                }

                // The arguments are visited whether or not the callee was found.
                for (int arg_index = 0; arg_index < node->data.fn_call_expr.params.length; arg_index += 1) {
                    AstNode *argument = node->data.fn_call_expr.params.at(arg_index);
                    analyze_expression(g, argument);
                }
                return;
            }

        case NodeTypeCastExpr:
        case NodeTypePrefixOpExpr:
            zig_panic("TODO");
            return;

        case NodeTypeNumberLiteral:
        case NodeTypeStringLiteral:
        case NodeTypeUnreachable:
        case NodeTypeSymbol:
            // nothing to do
            return;

        // None of the remaining node kinds is an expression.
        case NodeTypeDirective:
        case NodeTypeFnDecl:
        case NodeTypeFnProto:
        case NodeTypeParamDecl:
        case NodeTypeType:
        case NodeTypeRoot:
        case NodeTypeRootExportDecl:
        case NodeTypeExternBlock:
        case NodeTypeFnDef:
            zig_unreachable();
    }
}
static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
switch (node->type) {
case NodeTypeFnDef:
@ -371,7 +458,13 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
}
check_fn_def_control_flow(g, node);
analyze_expression(g, node->data.fn_def.body);
BlockContext context;
context.node = node;
context.root = &context;
context.parent = nullptr;
TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry;
analyze_expression(g, &context, expected_type, node->data.fn_def.body);
}
break;
@ -424,6 +517,12 @@ static void analyze_root(CodeGen *g, AstNode *node) {
}
static void define_primitive_types(CodeGen *g) {
{
// if this type is anywhere in the AST, we should never hit codegen.
TypeTableEntry *entry = allocate<TypeTableEntry>(1);
buf_init_from_str(&entry->name, "(invalid)");
g->builtin_types.entry_invalid = entry;
}
{
TypeTableEntry *entry = allocate<TypeTableEntry>(1);
entry->type_ref = LLVMInt8Type();
@ -450,15 +549,12 @@ static void define_primitive_types(CodeGen *g) {
LLVMZigEncoding_DW_ATE_unsigned());
g->type_table.put(&entry->name, entry);
g->builtin_types.entry_void = entry;
// invalid types are void
g->builtin_types.entry_invalid = entry;
}
{
TypeTableEntry *entry = allocate<TypeTableEntry>(1);
entry->type_ref = LLVMVoidType();
buf_init_from_str(&entry->name, "unreachable");
entry->di_type = g->builtin_types.entry_invalid->di_type;
entry->di_type = g->builtin_types.entry_void->di_type;
g->type_table.put(&entry->name, entry);
g->builtin_types.entry_unreachable = entry;
}