diff --git a/README.md b/README.md index 6311010156..fd3396620a 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ make ## Roadmap * variable declarations and assignment expressions - * Multiple files * Type checking * inline assembly and syscalls * running code at compile time diff --git a/example/multiple_files/foo.zig b/example/multiple_files/foo.zig index a42e58397c..021bf71e5b 100644 --- a/example/multiple_files/foo.zig +++ b/example/multiple_files/foo.zig @@ -6,6 +6,6 @@ fn private_function() { puts("it works!"); } -fn print_text() { +pub fn print_text() { private_function(); } diff --git a/example/multiple_files/libc.zig b/example/multiple_files/libc.zig index 7d1a5bebd9..19c1106fd6 100644 --- a/example/multiple_files/libc.zig +++ b/example/multiple_files/libc.zig @@ -1,5 +1,5 @@ #link("c") extern { - fn puts(s: *mut u8) -> i32; - fn exit(code: i32) -> unreachable; + pub fn puts(s: *mut u8) -> i32; + pub fn exit(code: i32) -> unreachable; } diff --git a/example/multiple_files/main.zig b/example/multiple_files/main.zig index 9bc489e925..c7cb2412e2 100644 --- a/example/multiple_files/main.zig +++ b/example/multiple_files/main.zig @@ -3,7 +3,7 @@ export executable "test"; use "libc.zig"; use "foo.zig"; -fn _start() -> unreachable { +export fn _start() -> unreachable { private_function(); } diff --git a/src/analyze.cpp b/src/analyze.cpp index fccbad8bb7..d7d5aeb5b0 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -137,6 +137,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i); assert(fn_decl->type == NodeTypeFnDecl); AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto; + bool is_pub = (fn_proto->data.fn_proto.visib_mod == FnProtoVisibModPub); resolve_function_proto(g, fn_proto); Buf *name = &fn_proto->data.fn_proto.name; @@ -145,7 +146,12 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, fn_table_entry->is_extern = true; fn_table_entry->calling_convention = LLVMCCallConv; fn_table_entry->import_entry = import; - g->fn_table.put(name, fn_table_entry); + + g->fn_protos.append(fn_table_entry); + import->fn_table.put(name, fn_table_entry); + if (is_pub) { + g->fn_table.put(name, fn_table_entry); + } } break; case NodeTypeFnDef: @@ -153,27 +159,44 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, AstNode *proto_node = node->data.fn_def.fn_proto; assert(proto_node->type == NodeTypeFnProto); Buf *proto_name = &proto_node->data.fn_proto.name; - auto entry = g->fn_table.maybe_get(proto_name); + auto entry = import->fn_table.maybe_get(proto_name); + bool skip = false; + bool is_internal = (proto_node->data.fn_proto.visib_mod != FnProtoVisibModExport); + bool is_pub = (proto_node->data.fn_proto.visib_mod == FnProtoVisibModPub); if (entry) { add_node_error(g, node, buf_sprintf("redefinition of '%s'", buf_ptr(proto_name))); assert(!node->codegen_node); node->codegen_node = allocate(1); node->codegen_node->data.fn_def_node.skip = true; - } else { + skip = true; + } else if (is_pub) { + auto entry = g->fn_table.maybe_get(proto_name); + if (entry) { + add_node_error(g, node, + buf_sprintf("redefinition of '%s'", buf_ptr(proto_name))); + assert(!node->codegen_node); + node->codegen_node = allocate(1); + node->codegen_node->data.fn_def_node.skip = true; + skip = true; + } + } + if (!skip) { FnTableEntry *fn_table_entry = allocate(1); fn_table_entry->import_entry = import; fn_table_entry->proto_node = proto_node; fn_table_entry->fn_def_node = node; - fn_table_entry->internal_linkage = proto_node->data.fn_proto.visib_mod != FnProtoVisibModExport; - if (fn_table_entry->internal_linkage) { - fn_table_entry->calling_convention = LLVMFastCallConv; - } else { - fn_table_entry->calling_convention = LLVMCCallConv; - } - g->fn_table.put(proto_name, fn_table_entry); + fn_table_entry->internal_linkage = is_internal; + fn_table_entry->calling_convention = is_internal ? LLVMFastCallConv : LLVMCCallConv; + + g->fn_protos.append(fn_table_entry); g->fn_defs.append(fn_table_entry); + import->fn_table.put(proto_name, fn_table_entry); + if (is_pub) { + g->fn_table.put(proto_name, fn_table_entry); + } + resolve_function_proto(g, proto_node); } } @@ -297,28 +320,31 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) { } } -static void analyze_expression(CodeGen *g, AstNode *node) { +static void analyze_expression(CodeGen *g, ImportTableEntry *import, AstNode *node) { switch (node->type) { case NodeTypeBlock: for (int i = 0; i < node->data.block.statements.length; i += 1) { AstNode *child = node->data.block.statements.at(i); - analyze_expression(g, child); + analyze_expression(g, import, child); } break; case NodeTypeReturnExpr: if (node->data.return_expr.expr) { - analyze_expression(g, node->data.return_expr.expr); + analyze_expression(g, import, node->data.return_expr.expr); } break; case NodeTypeBinOpExpr: - analyze_expression(g, node->data.bin_op_expr.op1); - analyze_expression(g, node->data.bin_op_expr.op2); + analyze_expression(g, import, node->data.bin_op_expr.op1); + analyze_expression(g, import, node->data.bin_op_expr.op2); break; case NodeTypeFnCallExpr: { Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr); - auto entry = g->fn_table.maybe_get(name); + auto entry = import->fn_table.maybe_get(name); + if (!entry) + entry = g->fn_table.maybe_get(name); + if (!entry) { add_node_error(g, node, buf_sprintf("undefined function: '%s'", buf_ptr(name))); @@ -336,7 +362,7 @@ static void analyze_expression(CodeGen *g, AstNode *node) { for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) { AstNode *child = node->data.fn_call_expr.params.at(i); - analyze_expression(g, child); + analyze_expression(g, import, child); } break; } @@ -366,7 +392,7 @@ static void analyze_expression(CodeGen *g, AstNode *node) { } } -static void analyze_top_level_declaration(CodeGen *g, AstNode *node) { +static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, AstNode *node) { switch (node->type) { case NodeTypeFnDef: { @@ -387,7 +413,7 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) { } check_fn_def_control_flow(g, node); - analyze_expression(g, node->data.fn_def.body); + analyze_expression(g, import, node->data.fn_def.body); } break; @@ -423,32 +449,49 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) { } } -static void analyze_root(CodeGen *g, ImportTableEntry *import, AstNode *node) { +static void find_function_declarations_root(CodeGen *g, ImportTableEntry *import, AstNode *node) { assert(node->type == NodeTypeRoot); - // find function declarations for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) { AstNode *child = node->data.root.top_level_decls.at(i); preview_function_declarations(g, import, child); } +} + +static void analyze_top_level_decls_root(CodeGen *g, ImportTableEntry *import, AstNode *node) { + assert(node->type == NodeTypeRoot); + for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) { AstNode *child = node->data.root.top_level_decls.at(i); - analyze_top_level_declaration(g, child); + analyze_top_level_declaration(g, import, child); } - } void semantic_analyze(CodeGen *g) { - auto it = g->import_table.entry_iterator(); - for (;;) { - auto *entry = it.next(); - if (!entry) - break; + { + auto it = g->import_table.entry_iterator(); + for (;;) { + auto *entry = it.next(); + if (!entry) + break; - ImportTableEntry *import = entry->value; - analyze_root(g, import, import->root); + ImportTableEntry *import = entry->value; + find_function_declarations_root(g, import, import->root); + } } + { + auto it = g->import_table.entry_iterator(); + for (;;) { + auto *entry = it.next(); + if (!entry) + break; + + ImportTableEntry *import = entry->value; + analyze_top_level_decls_root(g, import, import->root); + } + } + if (!g->root_out_name) { add_node_error(g, g->root_import->root, diff --git a/src/codegen.cpp b/src/codegen.cpp index 2f9bf47383..5ce95023e1 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -125,7 +125,13 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr); - FnTableEntry *fn_table_entry = g->fn_table.get(name); + FnTableEntry *fn_table_entry; + auto entry = g->cur_fn->import_entry->fn_table.maybe_get(name); + if (entry) + fn_table_entry = entry->value; + else + fn_table_entry = g->fn_table.get(name); + assert(fn_table_entry->proto_node->type == NodeTypeFnProto); int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length; int actual_param_count = node->data.fn_call_expr.params.length; @@ -478,13 +484,8 @@ static void do_code_gen(CodeGen *g) { // Generate function prototypes - auto it = g->fn_table.entry_iterator(); - for (;;) { - auto *entry = it.next(); - if (!entry) - break; - - FnTableEntry *fn_table_entry = entry->value; + for (int i = 0; i < g->fn_protos.length; i += 1) { + FnTableEntry *fn_table_entry = g->fn_protos.at(i); AstNode *proto_node = fn_table_entry->proto_node; assert(proto_node->type == NodeTypeFnProto); @@ -547,6 +548,7 @@ static void do_code_gen(CodeGen *g) { assert(codegen_node); FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node; + assert(codegen_fn_def); codegen_fn_def->params = allocate(LLVMCountParams(fn)); LLVMGetParams(fn, codegen_fn_def->params); @@ -733,9 +735,9 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *source_path, Buf *sou if (!entry) { Buf full_path = BUF_INIT; os_path_join(g->root_source_dir, &top_level_decl->data.use.path, &full_path); - Buf import_code = BUF_INIT; - os_fetch_file_path(&full_path, &import_code); - codegen_add_code(g, &top_level_decl->data.use.path, &import_code); + Buf *import_code = buf_alloc(); + os_fetch_file_path(&full_path, import_code); + codegen_add_code(g, &top_level_decl->data.use.path, import_code); } } diff --git a/src/semantic_info.hpp b/src/semantic_info.hpp index 075b1c8bdc..f0a79d685d 100644 --- a/src/semantic_info.hpp +++ b/src/semantic_info.hpp @@ -80,7 +80,14 @@ struct CodeGen { Buf *root_source_dir; Buf *root_out_name; ZigList block_scopes; + + // The function definitions this module includes. There must be a corresponding + // fn_protos entry. ZigList fn_defs; + // The function prototypes this module includes. In the case of external declarations, + // there will not be a corresponding fn_defs entry. + ZigList fn_protos; + OutType out_type; FnTableEntry *cur_fn; bool c_stdint_used; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 69a99590ce..7597ba761a 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -14,13 +14,13 @@ struct TestSourceFile { const char *relative_path; - const char *text; + const char *source_code; }; struct TestCase { const char *case_name; const char *output; - const char *source; + ZigList source_files; ZigList compile_errors; ZigList compiler_args; ZigList program_args; @@ -31,11 +31,20 @@ static const char *tmp_source_path = ".tmp_source.zig"; static const char *tmp_exe_path = "./.tmp_exe"; static const char *zig_exe = "./zig"; -static void add_simple_case(const char *case_name, const char *source, const char *output) { +static void add_source_file(TestCase *test_case, const char *path, const char *source) { + test_case->source_files.add_one(); + test_case->source_files.last().relative_path = path; + test_case->source_files.last().source_code = source; +} + +static TestCase *add_simple_case(const char *case_name, const char *source, const char *output) { TestCase *test_case = allocate(1); test_case->case_name = case_name; test_case->output = output; - test_case->source = source; + + test_case->source_files.resize(1); + test_case->source_files.at(0).relative_path = tmp_source_path; + test_case->source_files.at(0).source_code = source; test_case->compiler_args.append("build"); test_case->compiler_args.append(tmp_source_path); @@ -52,15 +61,19 @@ static void add_simple_case(const char *case_name, const char *source, const cha test_case->compiler_args.append("on"); test_cases.append(test_case); + + return test_case; } -static void add_compile_fail_case(const char *case_name, const char *source, int count, ...) { +static TestCase *add_compile_fail_case(const char *case_name, const char *source, int count, ...) { va_list ap; va_start(ap, count); TestCase *test_case = allocate(1); test_case->case_name = case_name; - test_case->source = source; + test_case->source_files.resize(1); + test_case->source_files.at(0).relative_path = tmp_source_path; + test_case->source_files.at(0).source_code = source; for (int i = 0; i < count; i += 1) { const char *arg = va_arg(ap, const char *); @@ -78,6 +91,8 @@ static void add_compile_fail_case(const char *case_name, const char *source, int test_cases.append(test_case); va_end(ap); + + return test_case; } static void add_compiling_test_cases(void) { @@ -135,6 +150,45 @@ static void add_compiling_test_cases(void) { exit(0); } )SOURCE", "OK\n"); + + { + TestCase *tc = add_simple_case("multiple files with private function", R"SOURCE( + use "libc.zig"; + use "foo.zig"; + + export fn _start() -> unreachable { + private_function(); + } + + fn private_function() -> unreachable { + print_text(); + exit(0); + } + )SOURCE", "OK\n"); + + add_source_file(tc, "libc.zig", R"SOURCE( + #link("c") + extern { + pub fn puts(s: *mut u8) -> i32; + pub fn exit(code: i32) -> unreachable; + } + )SOURCE"); + + add_source_file(tc, "foo.zig", R"SOURCE( + use "libc.zig"; + + // purposefully conflicting function with main source file + // but it's private so it should be OK + fn private_function() { + puts("OK"); + } + + pub fn print_text() { + private_function(); + } + )SOURCE"); + } + } static void add_compile_failure_test_cases(void) { @@ -207,7 +261,12 @@ static void print_compiler_invokation(TestCase *test_case, Buf *zig_stderr) { } static void run_test(TestCase *test_case) { - os_write_file(buf_create_from_str(tmp_source_path), buf_create_from_str(test_case->source)); + for (int i = 0; i < test_case->source_files.length; i += 1) { + TestSourceFile *test_source = &test_case->source_files.at(i); + os_write_file( + buf_create_from_str(test_source->relative_path), + buf_create_from_str(test_source->source_code)); + } Buf zig_stderr = BUF_INIT; Buf zig_stdout = BUF_INIT; @@ -265,6 +324,11 @@ static void run_test(TestCase *test_case) { printf("=======================================\n"); exit(1); } + + for (int i = 0; i < test_case->source_files.length; i += 1) { + TestSourceFile *test_source = &test_case->source_files.at(i); + remove(test_source->relative_path); + } } static void run_all_tests(void) {