diff --git a/example/shared_library/mathtest.zig b/example/shared_library/mathtest.zig index 35ec901523..e708cfe216 100644 --- a/example/shared_library/mathtest.zig +++ b/example/shared_library/mathtest.zig @@ -2,5 +2,5 @@ export library "mathtest"; export fn add(a: i32, b: i32) -> i32 { - return a + b; + a + b } diff --git a/src/analyze.cpp b/src/analyze.cpp index c70b2d6695..004f3fa0de 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -74,7 +74,7 @@ TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool } } -static void resolve_type(CodeGen *g, AstNode *node) { +static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) { assert(!node->codegen_node); node->codegen_node = allocate(1); TypeNode *type_node = &node->codegen_node->data.type_node; @@ -90,7 +90,7 @@ static void resolve_type(CodeGen *g, AstNode *node) { buf_sprintf("invalid type name: '%s'", buf_ptr(name))); type_node->entry = g->builtin_types.entry_invalid; } - break; + return type_node->entry; } case AstNodeTypeTypePointer: { @@ -101,12 +101,12 @@ static void resolve_type(CodeGen *g, AstNode *node) { buf_create_from_str("pointer to unreachable not allowed")); } type_node->entry = get_pointer_to_type(g, child_type, node->data.type.is_const); - break; + return type_node->entry; } } } -static void resolve_function_proto(CodeGen *g, AstNode *node) { +static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry) { assert(node->type == NodeTypeFnProto); for (int i = 0; i < node->data.fn_proto.directives->length; i += 1) { @@ -120,9 +120,11 @@ static void resolve_function_proto(CodeGen *g, AstNode *node) { AstNode *child = node->data.fn_proto.params.at(i); assert(child->type == NodeTypeParamDecl); - // parameter names are not important here. - - resolve_type(g, child->data.param_decl.type); + Buf *param_name = &child->data.param_decl.name; + SymbolTableEntry *symbol_entry = allocate(1); + symbol_entry->type_entry = resolve_type(g, child->data.param_decl.type); + symbol_entry->param_index = i; + fn_table_entry->symbol_table.put(param_name, symbol_entry); } resolve_type(g, node->data.fn_proto.return_type); @@ -148,20 +150,26 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, assert(fn_decl->type == NodeTypeFnDecl); AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto; bool is_pub = (fn_proto->data.fn_proto.visib_mod == FnProtoVisibModPub); - resolve_function_proto(g, fn_proto); - Buf *name = &fn_proto->data.fn_proto.name; FnTableEntry *fn_table_entry = allocate(1); fn_table_entry->proto_node = fn_proto; fn_table_entry->is_extern = true; fn_table_entry->calling_convention = LLVMCCallConv; fn_table_entry->import_entry = import; + fn_table_entry->symbol_table.init(8); + resolve_function_proto(g, fn_proto, fn_table_entry); + + Buf *name = &fn_proto->data.fn_proto.name; g->fn_protos.append(fn_table_entry); import->fn_table.put(name, fn_table_entry); if (is_pub) { g->fn_table.put(name, fn_table_entry); } + + assert(!fn_proto->codegen_node); + fn_proto->codegen_node = allocate(1); + fn_proto->codegen_node->data.fn_proto_node.fn_table_entry = fn_table_entry; } break; case NodeTypeFnDef: @@ -198,6 +206,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, fn_table_entry->fn_def_node = node; fn_table_entry->internal_linkage = is_internal; fn_table_entry->calling_convention = is_internal ? LLVMFastCallConv : LLVMCCallConv; + fn_table_entry->symbol_table.init(8); g->fn_protos.append(fn_table_entry); g->fn_defs.append(fn_table_entry); @@ -207,7 +216,11 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, g->fn_table.put(proto_name, fn_table_entry); } - resolve_function_proto(g, proto_node); + resolve_function_proto(g, proto_node, fn_table_entry); + + assert(!proto_node->codegen_node); + proto_node->codegen_node = allocate(1); + proto_node->codegen_node->data.fn_proto_node.fn_table_entry = fn_table_entry; } } break; @@ -289,6 +302,16 @@ static TypeTableEntry * get_return_type(BlockContext *context) { return return_type_node->codegen_node->data.type_node.entry; } +static FnTableEntry *get_context_fn_entry(BlockContext *context) { + AstNode *fn_def_node = context->root->node; + assert(fn_def_node->type == NodeTypeFnDef); + AstNode *fn_proto_node = fn_def_node->data.fn_def.fn_proto; + assert(fn_proto_node->type == NodeTypeFnProto); + assert(fn_proto_node->codegen_node); + assert(fn_proto_node->codegen_node->data.fn_proto_node.fn_table_entry); + return fn_proto_node->codegen_node->data.fn_proto_node.fn_table_entry; +} + static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry *expected_type, TypeTableEntry *actual_type) { if (expected_type == nullptr) return; // anything will do @@ -482,9 +505,20 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, break; case NodeTypeSymbol: - // look up symbol in symbol table - zig_panic("TODO analyze_expression symbol"); - + { + Buf *symbol_name = &node->data.symbol; + FnTableEntry *fn_table_entry = get_context_fn_entry(context); + auto table_entry = fn_table_entry->symbol_table.maybe_get(symbol_name); + if (table_entry) { + SymbolTableEntry *symbol_entry = table_entry->value; + return_type = symbol_entry->type_entry; + } else { + add_node_error(g, node, + buf_sprintf("use of undeclared identifier '%s'", buf_ptr(symbol_name))); + return_type = g->builtin_types.entry_invalid; + } + break; + } case NodeTypeCastExpr: zig_panic("TODO analyze_expression cast expr"); break; diff --git a/src/codegen.cpp b/src/codegen.cpp index d442856692..45bbfeb310 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -105,19 +105,13 @@ static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) { static LLVMValueRef get_variable_value(CodeGen *g, Buf *name) { assert(g->cur_fn->proto_node->type == NodeTypeFnProto); - int param_count = g->cur_fn->proto_node->data.fn_proto.params.length; - for (int i = 0; i < param_count; i += 1) { - AstNode *param_decl_node = g->cur_fn->proto_node->data.fn_proto.params.at(i); - assert(param_decl_node->type == NodeTypeParamDecl); - Buf *param_name = ¶m_decl_node->data.param_decl.name; - if (buf_eql_buf(name, param_name)) { - CodeGenNode *codegen_node = g->cur_fn->fn_def_node->codegen_node; - assert(codegen_node); - FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node; - return codegen_fn_def->params[i]; - } - } - zig_unreachable(); + + SymbolTableEntry *symbol_entry = g->cur_fn->symbol_table.get(name); + + CodeGenNode *codegen_node = g->cur_fn->fn_def_node->codegen_node; + assert(codegen_node); + FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node; + return codegen_fn_def->params[symbol_entry->param_index]; } static TypeTableEntry *get_expr_type(AstNode *node) { diff --git a/src/semantic_info.hpp b/src/semantic_info.hpp index 85da09cb16..2993948446 100644 --- a/src/semantic_info.hpp +++ b/src/semantic_info.hpp @@ -38,6 +38,11 @@ struct ImportTableEntry { HashMap fn_table; }; +struct SymbolTableEntry { + TypeTableEntry *type_entry; + int param_index; // only valid in the case of parameters +}; + struct FnTableEntry { LLVMValueRef fn_value; AstNode *proto_node; @@ -46,6 +51,9 @@ struct FnTableEntry { bool internal_linkage; unsigned calling_convention; ImportTableEntry *import_entry; + + // reminder: hash tables must be initialized before use + HashMap symbol_table; }; struct CodeGen { @@ -106,6 +114,10 @@ struct TypeNode { TypeTableEntry *entry; }; +struct FnProtoNode { + FnTableEntry *fn_table_entry; +}; + struct FnDefNode { TypeTableEntry *implicit_return_type; bool skip; @@ -121,6 +133,7 @@ struct CodeGenNode { TypeNode type_node; // for NodeTypeType FnDefNode fn_def_node; // for NodeTypeFnDef ExprNode expr_node; // for all the expression nodes + FnProtoNode fn_proto_node; // for NodeTypeFnProto } data; }; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index c9855e8606..d4460bbdc2 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -213,6 +213,25 @@ static void add_compiling_test_cases(void) { exit(0); } )SOURCE", "1 is true\n!0 is true\n"); + + add_simple_case("params", R"SOURCE( + #link("c") + extern { + fn puts(s: *const u8) -> i32; + fn exit(code: i32) -> unreachable; + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + + export fn _start() -> unreachable { + if add(22, 11) == 33 { + puts("pass"); + } + exit(0); + } + )SOURCE", "pass\n"); } static void add_compile_failure_test_cases(void) {