diff --git a/src/all_types.hpp b/src/all_types.hpp index cf41444f0b..0b03388502 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1725,6 +1725,7 @@ struct CodeGen { LLVMValueRef cur_async_resume_index_ptr; LLVMValueRef cur_async_awaiter_ptr; LLVMValueRef cur_async_prev_val; + LLVMValueRef cur_async_prev_val_field_ptr; LLVMBasicBlockRef cur_preamble_llvm_block; size_t cur_resume_block_count; LLVMValueRef cur_err_ret_trace_val_arg; @@ -1886,6 +1887,7 @@ struct CodeGen { bool system_linker_hack; bool reported_bad_link_libc_error; bool is_dynamic; // shared library rather than static library. dynamic musl rather than static musl. + bool cur_is_after_return; //////////////////////////// Participates in Input Parameter Cache Hash /////// Note: there is a separate cache hash for builtin.zig, when adding fields, @@ -3639,8 +3641,6 @@ struct IrInstructionCoroResume { struct IrInstructionTestCancelRequested { IrInstruction base; - - bool use_return_begin_prev_value; }; enum ResultLocId { @@ -3730,7 +3730,8 @@ static const size_t err_union_payload_index = 1; static const size_t coro_fn_ptr_index = 0; static const size_t coro_resume_index = 1; static const size_t coro_awaiter_index = 2; -static const size_t coro_ret_start = 3; +static const size_t coro_prev_val_index = 3; +static const size_t coro_ret_start = 4; // TODO call graph analysis to find out what this number needs to be for every function // MUST BE A POWER OF TWO. diff --git a/src/analyze.cpp b/src/analyze.cpp index cc90573f41..a09ba582c9 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5246,6 +5246,9 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { field_names.append("@awaiter"); field_types.append(g->builtin_types.entry_usize); + field_names.append("@prev_val"); + field_types.append(g->builtin_types.entry_usize); + FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false); field_names.append("@result_ptr_callee"); @@ -7592,6 +7595,7 @@ static void resolve_llvm_types_any_frame(CodeGen *g, ZigType *any_frame_type, Re field_types.append(ptr_fn_llvm_type); // fn_ptr field_types.append(usize_type_ref); // resume_index field_types.append(usize_type_ref); // awaiter + field_types.append(usize_type_ref); // prev_val bool have_result_type = result_type != nullptr && type_has_bits(result_type); if (have_result_type) { diff --git a/src/codegen.cpp b/src/codegen.cpp index 46cd8e9fcf..5a8fd3e9ca 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2226,7 +2226,18 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar return ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); } +static LLVMValueRef get_cur_async_prev_val(CodeGen *g) { + if (g->cur_async_prev_val != nullptr) { + return g->cur_async_prev_val; + } + g->cur_async_prev_val = LLVMBuildLoad(g->builder, g->cur_async_prev_val_field_ptr, ""); + return g->cur_async_prev_val; +} + static LLVMBasicBlockRef gen_suspend_begin(CodeGen *g, const char *name_hint) { + // This becomes invalid when a suspend happens. + g->cur_async_prev_val = nullptr; + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, name_hint); size_t new_block_index = g->cur_resume_block_count; @@ -2319,6 +2330,9 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, LLVMBasicBlockRef incoming_blocks[] = { after_resume_block, switch_bb }; LLVMAddIncoming(g->cur_async_prev_val, incoming_values, incoming_blocks, 2); + g->cur_is_after_return = true; + LLVMBuildStore(g->builder, g->cur_async_prev_val, g->cur_async_prev_val_field_ptr); + if (!ret_type_has_bits) { return nullptr; } @@ -2366,7 +2380,7 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns ZigType *any_frame_type = get_any_frame_type(g, ret_type); LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false); LLVMValueRef mask_val = LLVMConstNot(one); - LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, g->cur_async_prev_val, mask_val, ""); + LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, get_cur_async_prev_val(g), mask_val, ""); LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, masked_prev_val, get_llvm_type(g, any_frame_type), ""); LLVMValueRef call_inst = gen_resume(g, nullptr, their_frame_ptr, ResumeIdReturn, nullptr); @@ -5590,8 +5604,8 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex { if (!fn_is_async(g->cur_fn)) return LLVMConstInt(LLVMInt1Type(), 0, false); - if (instruction->use_return_begin_prev_value) { - return LLVMBuildTrunc(g->builder, g->cur_async_prev_val, LLVMInt1Type(), ""); + if (g->cur_is_after_return) { + return LLVMBuildTrunc(g->builder, get_cur_async_prev_val(g), LLVMInt1Type(), ""); } else { zig_panic("TODO"); } @@ -7063,6 +7077,7 @@ static void do_code_gen(CodeGen *g) { } if (is_async) { + g->cur_is_after_return = false; g->cur_resume_block_count = 0; LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; @@ -7099,6 +7114,8 @@ static void do_code_gen(CodeGen *g) { g->cur_err_ret_trace_val_stack = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, trace_field_index_stack, ""); } + g->cur_async_prev_val_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, + coro_prev_val_index, ""); LLVMValueRef resume_index = LLVMBuildLoad(g->builder, resume_index_ptr, ""); LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, resume_index, bad_resume_block, 4); diff --git a/src/ir.cpp b/src/ir.cpp index 5fc31db3ef..4dcfaa6cce 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -3325,12 +3325,9 @@ static IrInstruction *ir_build_coro_resume(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } -static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scope, AstNode *source_node, - bool use_return_begin_prev_value) -{ +static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scope, AstNode *source_node) { IrInstructionTestCancelRequested *instruction = ir_build_instruction(irb, scope, source_node); instruction->base.value.type = irb->codegen->builtin_types.entry_bool; - instruction->use_return_begin_prev_value = use_return_begin_prev_value; return &instruction->base; } @@ -3546,7 +3543,7 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, if (need_test_cancel) { ir_set_cursor_at_end_and_append_block(irb, ok_block); - IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, scope, node, true); + IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, scope, node); ir_mark_gen(ir_build_cond_br(irb, scope, node, is_canceled, all_defers_block, normal_defers_block, force_comptime)); } @@ -3830,7 +3827,7 @@ static IrInstruction *ir_gen_block(IrBuilder *irb, Scope *parent_scope, AstNode ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false); return ir_mark_gen(ir_build_return(irb, child_scope, result->source_node, result)); } - IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, child_scope, block_node, true); + IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, child_scope, block_node); IrBasicBlock *all_defers_block = ir_create_basic_block(irb, child_scope, "ErrDefers"); IrBasicBlock *normal_defers_block = ir_create_basic_block(irb, child_scope, "Defers"); IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, child_scope, "RetStmt"); @@ -24725,8 +24722,7 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) { return ir_const_bool(ira, &instruction->base, false); } - return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node, - instruction->use_return_begin_prev_value); + return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node); } static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) { diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 8b8445f625..8c90eb02f3 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1551,8 +1551,7 @@ static void ir_print_await_gen(IrPrint *irp, IrInstructionAwaitGen *instruction) } static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancelRequested *instruction) { - const char *arg = instruction->use_return_begin_prev_value ? "UseReturnBeginPrevValue" : "AdditionalCheck"; - fprintf(irp->f, "@testCancelRequested(%s)", arg); + fprintf(irp->f, "@testCancelRequested()"); } static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 57706c2455..c2b95e8559 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -318,7 +318,7 @@ test "@asyncCall with return type" { } }; var foo = Foo{ .bar = Foo.middle }; - var bytes: [100]u8 = undefined; + var bytes: [150]u8 = undefined; var aresult: i32 = 0; _ = @asyncCall(&bytes, &aresult, foo.bar); expect(aresult == 0); @@ -589,3 +589,27 @@ test "pass string literal to async function" { }; S.doTheTest(); } + +test "cancel inside an errdefer" { + const S = struct { + var frame: anyframe = undefined; + + fn doTheTest() void { + _ = async amainWrap(); + resume frame; + } + + fn amainWrap() !void { + var foo = async func(); + errdefer cancel foo; + return error.Bad; + } + + fn func() void { + frame = @frame(); + suspend; + } + + }; + S.doTheTest(); +}