diff --git a/src/analyze.cpp b/src/analyze.cpp index 7d99ca92da..d13b62bcfd 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -6367,7 +6367,9 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { IrInstGen *instruction = block->instruction_list.at(instr_i); if (instruction->id == IrInstGenIdAwait || instruction->id == IrInstGenIdVarPtr || - instruction->id == IrInstGenIdAlloca) + instruction->id == IrInstGenIdAlloca || + instruction->id == IrInstGenIdSpillBegin || + instruction->id == IrInstGenIdSpillEnd) { // This instruction does its own spilling specially, or otherwise doesn't need it. continue; diff --git a/src/codegen.cpp b/src/codegen.cpp index 15a3ca7e95..5a831819a5 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2561,7 +2561,12 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir LLVMBuildRet(g->builder, by_val_value); } } else if (instruction->operand == nullptr) { - LLVMBuildRetVoid(g->builder); + if (g->cur_ret_ptr == nullptr) { + LLVMBuildRetVoid(g->builder); + } else { + LLVMValueRef by_val_value = gen_load_untyped(g, g->cur_ret_ptr, 0, false, ""); + LLVMBuildRet(g->builder, by_val_value); + } } else { LLVMValueRef value = ir_llvm_value(g, instruction->operand); LLVMBuildRet(g->builder, value); @@ -5715,18 +5720,24 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex bool want_safety = instruction->safety_check_on && ir_want_runtime_safety(g, &instruction->base) && g->errors_by_index.length > 1; - bool value_has_bits; - if ((err = type_has_bits2(g, instruction->base.value->type, &value_has_bits))) - codegen_report_errors_and_exit(g); - - if (!want_safety && !value_has_bits) - return nullptr; - ZigType *ptr_type = instruction->value->value->type; assert(ptr_type->id == ZigTypeIdPointer); ZigType *err_union_type = ptr_type->data.pointer.child_type; ZigType *payload_type = err_union_type->data.error_union.payload_type; LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->value); + + LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, g->err_tag_type)); + bool value_has_bits; + if ((err = type_has_bits2(g, instruction->base.value->type, &value_has_bits))) + codegen_report_errors_and_exit(g); + if (!want_safety && !value_has_bits) { + if (instruction->initializing) { + gen_store_untyped(g, zero, err_union_ptr, 0, false); + } + return nullptr; + } + + LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type); if (!type_has_bits(err_union_type->data.error_union.err_set_type)) { @@ -5741,7 +5752,6 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex } else { err_val = err_union_handle; } - LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, g->err_tag_type)); LLVMValueRef cond_val = LLVMBuildICmp(g->builder, LLVMIntEQ, err_val, zero, ""); LLVMBasicBlockRef err_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapErrError"); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapErrOk"); @@ -5761,6 +5771,9 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex } return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_payload_index, ""); } else { + if (instruction->initializing) { + gen_store_untyped(g, zero, err_union_ptr, 0, false); + } return nullptr; } } diff --git a/src/ir.cpp b/src/ir.cpp index 2f0c52347c..82fce395b0 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -5252,6 +5252,7 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node, return irb->codegen->invalid_inst_src; } else { return_value = ir_build_const_void(irb, scope, node); + ir_build_end_expr(irb, scope, node, return_value, &result_loc_ret->base); } ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value, result_loc_ret)); @@ -5262,7 +5263,7 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node, if (!have_err_defers && !irb->codegen->have_err_ret_tracing) { // only generate unconditional defers ir_gen_defers_for_block(irb, scope, outer_scope, false); - IrInstSrc *result = ir_build_return_src(irb, scope, node, return_value); + IrInstSrc *result = ir_build_return_src(irb, scope, node, nullptr); result_loc_ret->base.source_instruction = result; return result; } @@ -5271,10 +5272,6 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node, IrBasicBlockSrc *err_block = ir_create_basic_block(irb, scope, "ErrRetErr"); IrBasicBlockSrc *ok_block = ir_create_basic_block(irb, scope, "ErrRetOk"); - if (!have_err_defers) { - ir_gen_defers_for_block(irb, scope, outer_scope, false); - } - IrInstSrc *is_err = ir_build_test_err_src(irb, scope, node, return_value, false, true); IrInstSrc *is_comptime; @@ -5288,22 +5285,18 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node, IrBasicBlockSrc *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt"); ir_set_cursor_at_end_and_append_block(irb, err_block); - if (have_err_defers) { - ir_gen_defers_for_block(irb, scope, outer_scope, true); - } + ir_gen_defers_for_block(irb, scope, outer_scope, true); if (irb->codegen->have_err_ret_tracing && !should_inline) { ir_build_save_err_ret_addr_src(irb, scope, node); } ir_build_br(irb, scope, node, ret_stmt_block, is_comptime); ir_set_cursor_at_end_and_append_block(irb, ok_block); - if (have_err_defers) { - ir_gen_defers_for_block(irb, scope, outer_scope, false); - } + ir_gen_defers_for_block(irb, scope, outer_scope, false); ir_build_br(irb, scope, node, ret_stmt_block, is_comptime); ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block); - IrInstSrc *result = ir_build_return_src(irb, scope, node, return_value); + IrInstSrc *result = ir_build_return_src(irb, scope, node, nullptr); result_loc_ret->base.source_instruction = result; return result; } @@ -9622,7 +9615,10 @@ static IrInstSrc *ir_gen_catch(IrBuilderSrc *irb, Scope *parent_scope, AstNode * } - IrInstSrc *err_union_ptr = ir_gen_node_extra(irb, op1_node, parent_scope, LValPtr, nullptr); + ScopeExpr *spill_scope = create_expr_scope(irb->codegen, op1_node, parent_scope); + spill_scope->spill_harder = true; + + IrInstSrc *err_union_ptr = ir_gen_node_extra(irb, op1_node, &spill_scope->base, LValPtr, nullptr); if (err_union_ptr == irb->codegen->invalid_inst_src) return irb->codegen->invalid_inst_src; @@ -9644,7 +9640,7 @@ static IrInstSrc *ir_gen_catch(IrBuilderSrc *irb, Scope *parent_scope, AstNode * is_comptime); ir_set_cursor_at_end_and_append_block(irb, err_block); - Scope *subexpr_scope = create_runtime_scope(irb->codegen, node, parent_scope, is_comptime); + Scope *subexpr_scope = create_runtime_scope(irb->codegen, node, &spill_scope->base, is_comptime); Scope *err_scope; if (var_node) { assert(var_node->type == NodeTypeSymbol); @@ -15497,6 +15493,12 @@ static IrInstGen *ir_analyze_instruction_add_implicit_return_type(IrAnalyze *ira } static IrInstGen *ir_analyze_instruction_return(IrAnalyze *ira, IrInstSrcReturn *instruction) { + if (instruction->operand == nullptr) { + // result location mechanism took care of it. + IrInstGen *result = ir_build_return_gen(ira, &instruction->base.base, nullptr); + return ir_finish_anal(ira, result); + } + IrInstGen *operand = instruction->operand->child; if (type_is_invalid(operand->value->type)) return ir_unreach_error(ira); @@ -29551,8 +29553,13 @@ static IrInstGen *ir_analyze_instruction_spill_begin(IrAnalyze *ira, IrInstSrcSp if (!type_has_bits(operand->value->type)) return ir_const_void(ira, &instruction->base.base); - ir_assert(instruction->spill_id == SpillIdRetErrCode, &instruction->base.base); - ira->new_irb.exec->need_err_code_spill = true; + switch (instruction->spill_id) { + case SpillIdInvalid: + zig_unreachable(); + case SpillIdRetErrCode: + ira->new_irb.exec->need_err_code_spill = true; + break; + } return ir_build_spill_begin_gen(ira, &instruction->base.base, operand, instruction->spill_id); } @@ -29562,8 +29569,12 @@ static IrInstGen *ir_analyze_instruction_spill_end(IrAnalyze *ira, IrInstSrcSpil if (type_is_invalid(operand->value->type)) return ira->codegen->invalid_inst_gen; - if (ir_should_inline(ira->old_irb.exec, instruction->base.base.scope) || !type_has_bits(operand->value->type)) + if (ir_should_inline(ira->old_irb.exec, instruction->base.base.scope) || + !type_has_bits(operand->value->type) || + instr_is_comptime(operand)) + { return operand; + } ir_assert(instruction->begin->base.child->id == IrInstGenIdSpillBegin, &instruction->base.base); IrInstGenSpillBegin *begin = reinterpret_cast(instruction->begin->base.child); diff --git a/test/stage1/behavior/async_fn.zig b/test/stage1/behavior/async_fn.zig index fb17d5740c..dcd5102ff8 100644 --- a/test/stage1/behavior/async_fn.zig +++ b/test/stage1/behavior/async_fn.zig @@ -2,6 +2,7 @@ const std = @import("std"); const builtin = @import("builtin"); const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; +const expectError = std.testing.expectError; var global_x: i32 = 1; @@ -1440,3 +1441,43 @@ test "properly spill optional payload capture value" { resume S.global_frame; expect(S.global_int == 1237); } + +test "handle defer interfering with return value spill" { + const S = struct { + var global_frame1: anyframe = undefined; + var global_frame2: anyframe = undefined; + var finished = false; + var baz_happened = false; + + fn doTheTest() void { + _ = async testFoo(); + resume global_frame1; + resume global_frame2; + expect(baz_happened); + expect(finished); + } + + fn testFoo() void { + expectError(error.Bad, foo()); + finished = true; + } + + fn foo() anyerror!void { + defer baz(); + return bar() catch |err| return err; + } + + fn bar() anyerror!void { + global_frame1 = @frame(); + suspend; + return error.Bad; + } + + fn baz() void { + global_frame2 = @frame(); + suspend; + baz_happened = true; + } + }; + S.doTheTest(); +}