diff --git a/src/all_types.hpp b/src/all_types.hpp index 1a01754af5..dbaa3b5467 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2350,6 +2350,7 @@ struct IrInstructionVarPtr { IrInstruction base; ZigVar *var; + ScopeFnDef *crossed_fndef_scope; }; struct IrInstructionCall { diff --git a/src/analyze.cpp b/src/analyze.cpp index f511e82253..58fe9e7392 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3519,7 +3519,7 @@ ZigVar *add_variable(CodeGen *g, AstNode *source_node, Scope *parent_scope, Buf if (!type_is_invalid(value->type)) { variable_entry->align_bytes = get_abi_alignment(g, value->type); - ZigVar *existing_var = find_variable(g, parent_scope, name); + ZigVar *existing_var = find_variable(g, parent_scope, name, nullptr); if (existing_var && !existing_var->shadowable) { ErrorMsg *msg = add_node_error(g, source_node, buf_sprintf("redeclaration of variable '%s'", buf_ptr(name))); @@ -3726,12 +3726,16 @@ Tld *find_decl(CodeGen *g, Scope *scope, Buf *name) { return nullptr; } -ZigVar *find_variable(CodeGen *g, Scope *scope, Buf *name) { +ZigVar *find_variable(CodeGen *g, Scope *scope, Buf *name, ScopeFnDef **crossed_fndef_scope) { + ScopeFnDef *my_crossed_fndef_scope = nullptr; while (scope) { if (scope->id == ScopeIdVarDecl) { ScopeVarDecl *var_scope = (ScopeVarDecl *)scope; - if (buf_eql_buf(name, &var_scope->var->name)) + if (buf_eql_buf(name, &var_scope->var->name)) { + if (crossed_fndef_scope != nullptr) + *crossed_fndef_scope = my_crossed_fndef_scope; return var_scope->var; + } } else if (scope->id == ScopeIdDecls) { ScopeDecls *decls_scope = (ScopeDecls *)scope; auto entry = decls_scope->decl_table.maybe_get(name); @@ -3739,10 +3743,15 @@ ZigVar *find_variable(CodeGen *g, Scope *scope, Buf *name) { Tld *tld = entry->value; if (tld->id == TldIdVar) { TldVar *tld_var = (TldVar *)tld; - if (tld_var->var) + if (tld_var->var) { + if (crossed_fndef_scope != nullptr) + *crossed_fndef_scope = nullptr; return tld_var->var; + } } } + } else if (scope->id == ScopeIdFnDef) { + my_crossed_fndef_scope = (ScopeFnDef *)scope; } scope = scope->parent; } diff --git a/src/analyze.hpp b/src/analyze.hpp index f31b9b219d..357312f2dc 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -48,7 +48,7 @@ bool type_has_bits(ZigType *type_entry); ImportTableEntry *add_source_file(CodeGen *g, PackageTableEntry *package, Buf *abs_full_path, Buf *source_code); -ZigVar *find_variable(CodeGen *g, Scope *orig_context, Buf *name); +ZigVar *find_variable(CodeGen *g, Scope *orig_context, Buf *name, ScopeFnDef **crossed_fndef_scope); Tld *find_decl(CodeGen *g, Scope *scope, Buf *name); void resolve_top_level_decl(CodeGen *g, Tld *tld, bool pointer_only, AstNode *source_node); bool type_is_codegen_pointer(ZigType *type); diff --git a/src/ir.cpp b/src/ir.cpp index c9ce0d3be3..4e9c8d2bc7 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1114,15 +1114,22 @@ static IrInstruction *ir_build_bin_op_from(IrBuilder *irb, IrInstruction *old_in return new_instruction; } -static IrInstruction *ir_build_var_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigVar *var) { +static IrInstruction *ir_build_var_ptr_x(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigVar *var, + ScopeFnDef *crossed_fndef_scope) +{ IrInstructionVarPtr *instruction = ir_build_instruction(irb, scope, source_node); instruction->var = var; + instruction->crossed_fndef_scope = crossed_fndef_scope; ir_ref_var(var); return &instruction->base; } +static IrInstruction *ir_build_var_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigVar *var) { + return ir_build_var_ptr_x(irb, scope, source_node, var, nullptr); +} + static IrInstruction *ir_build_elem_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *array_ptr, IrInstruction *elem_index, bool safety_check_on, PtrLen ptr_len) { @@ -3336,7 +3343,7 @@ static ZigVar *create_local_var(CodeGen *codegen, AstNode *node, Scope *parent_s buf_init_from_buf(&variable_entry->name, name); if (!skip_name_check) { - ZigVar *existing_var = find_variable(codegen, parent_scope, name); + ZigVar *existing_var = find_variable(codegen, parent_scope, name, nullptr); if (existing_var && !existing_var->shadowable) { ErrorMsg *msg = add_node_error(codegen, node, buf_sprintf("redeclaration of variable '%s'", buf_ptr(name))); @@ -3799,9 +3806,10 @@ static IrInstruction *ir_gen_symbol(IrBuilder *irb, Scope *scope, AstNode *node, } } - ZigVar *var = find_variable(irb->codegen, scope, variable_name); + ScopeFnDef *crossed_fndef_scope; + ZigVar *var = find_variable(irb->codegen, scope, variable_name, &crossed_fndef_scope); if (var) { - IrInstruction *var_ptr = ir_build_var_ptr(irb, scope, node, var); + IrInstruction *var_ptr = ir_build_var_ptr_x(irb, scope, node, var, crossed_fndef_scope); if (lval == LValPtr) return var_ptr; else @@ -5822,7 +5830,9 @@ static IrInstruction *ir_gen_asm_expr(IrBuilder *irb, Scope *scope, AstNode *nod output_types[i] = return_type; } else { Buf *variable_name = asm_output->variable_name; - ZigVar *var = find_variable(irb->codegen, scope, variable_name); + // TODO there is some duplication here with ir_gen_symbol. I need to do a full audit of how + // inline assembly works. https://github.com/ziglang/zig/issues/215 + ZigVar *var = find_variable(irb->codegen, scope, variable_name, nullptr); if (var) { output_vars[i] = var; } else { @@ -14157,17 +14167,26 @@ static ZigType *ir_analyze_instruction_phi(IrAnalyze *ira, IrInstructionPhi *phi return resolved_type; } -static ZigType *ir_analyze_var_ptr(IrAnalyze *ira, IrInstruction *instruction, - ZigVar *var) -{ +static ZigType *ir_analyze_var_ptr(IrAnalyze *ira, IrInstruction *instruction, ZigVar *var) { IrInstruction *result = ir_get_var_ptr(ira, instruction, var); ir_link_new_instruction(result, instruction); return result->value.type; } -static ZigType *ir_analyze_instruction_var_ptr(IrAnalyze *ira, IrInstructionVarPtr *var_ptr_instruction) { - ZigVar *var = var_ptr_instruction->var; - return ir_analyze_var_ptr(ira, &var_ptr_instruction->base, var); +static ZigType *ir_analyze_instruction_var_ptr(IrAnalyze *ira, IrInstructionVarPtr *instruction) { + ZigVar *var = instruction->var; + IrInstruction *result = ir_get_var_ptr(ira, &instruction->base, var); + if (instruction->crossed_fndef_scope != nullptr && !instr_is_comptime(result)) { + ErrorMsg *msg = ir_add_error(ira, &instruction->base, + buf_sprintf("'%s' not accessible from inner function", buf_ptr(&var->name))); + add_error_note(ira->codegen, msg, instruction->crossed_fndef_scope->base.source_node, + buf_sprintf("crossed function definition here")); + add_error_note(ira->codegen, msg, var->decl_node, + buf_sprintf("declared here")); + return ira->codegen->builtin_types.entry_invalid; + } + ir_link_new_instruction(result, &instruction->base); + return result->value.type; } static ZigType *adjust_ptr_align(CodeGen *g, ZigType *ptr_type, uint32_t new_align) { diff --git a/test/cases/fn.zig b/test/cases/fn.zig index 47f7d5e688..a862f85b0d 100644 --- a/test/cases/fn.zig +++ b/test/cases/fn.zig @@ -176,3 +176,18 @@ test "pass by non-copying value as method, at comptime" { assert(pt.addPointCoords() == 3); } } + +fn outer(y: u32) fn (u32) u32 { + const Y = @typeOf(y); + const st = struct { + fn get(z: u32) u32 { + return z + @sizeOf(Y); + } + }; + return st.get; +} + +test "return inner function which references comptime variable of outer function" { + var func = outer(10); + assert(func(3) == 7); +} diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 00fc33d122..24d977c218 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -1,6 +1,26 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompileErrorContext) void { + cases.add( + "accessing runtime parameter from outer function", + \\fn outer(y: u32) fn (u32) u32 { + \\ const st = struct { + \\ fn get(z: u32) u32 { + \\ return z + y; + \\ } + \\ }; + \\ return st.get; + \\} + \\export fn entry() void { + \\ var func = outer(10); + \\ var x = func(3); + \\} + , + ".tmp_source.zig:4:24: error: 'y' not accessible from inner function", + ".tmp_source.zig:3:28: note: crossed function definition here", + ".tmp_source.zig:1:10: note: declared here", + ); + cases.add( "non int passed to @intToFloat", \\export fn entry() void {