From 8c39cdc89f2ae7fc25c3856e7c4c6b4662ac8a80 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 3 Jul 2018 21:36:16 -0400 Subject: [PATCH] fix await on early return when return type is struct previously, await on an early return would try to access the destroyed coroutine frame; now it copies the result into a temporary variable before destroying the coroutine frame --- src/ir.cpp | 12 +++---- test/behavior.zig | 1 + test/cases/coroutine_await_struct.zig | 47 +++++++++++++++++++++++++++ test/cases/coroutines.zig | 4 +-- 4 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 test/cases/coroutine_await_struct.zig diff --git a/src/ir.cpp b/src/ir.cpp index 5df5c1d676..b40c2dc36d 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -6674,7 +6674,10 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast } Buf *result_field_name = buf_create_from_str(RESULT_FIELD_NAME); IrInstruction *promise_result_ptr = ir_build_field_ptr(irb, parent_scope, node, coro_promise_ptr, result_field_name); + // If the type of the result handle_is_ptr then this does not actually perform a load. But we need it to, + // because we're about to destroy the memory. So we store it into our result variable. IrInstruction *no_suspend_result = ir_build_load_ptr(irb, parent_scope, node, promise_result_ptr); + ir_build_store_ptr(irb, parent_scope, node, my_result_var_ptr, no_suspend_result); ir_build_cancel(irb, parent_scope, node, target_inst); ir_build_br(irb, parent_scope, node, merge_block, const_bool_false); @@ -6696,17 +6699,10 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast ir_mark_gen(ir_build_br(irb, parent_scope, node, irb->exec->coro_final_cleanup_block, const_bool_false)); ir_set_cursor_at_end_and_append_block(irb, resume_block); - IrInstruction *yes_suspend_result = ir_build_load_ptr(irb, parent_scope, node, my_result_var_ptr); ir_build_br(irb, parent_scope, node, merge_block, const_bool_false); ir_set_cursor_at_end_and_append_block(irb, merge_block); - IrBasicBlock **incoming_blocks = allocate(2); - IrInstruction **incoming_values = allocate(2); - incoming_blocks[0] = resume_block; - incoming_values[0] = yes_suspend_result; - incoming_blocks[1] = no_suspend_block; - incoming_values[1] = no_suspend_result; - return ir_build_phi(irb, parent_scope, node, 2, incoming_blocks, incoming_values); + return ir_build_load_ptr(irb, parent_scope, node, my_result_var_ptr); } static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNode *node) { diff --git a/test/behavior.zig b/test/behavior.zig index 3766ed4305..d47eb8fd6c 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -18,6 +18,7 @@ comptime { _ = @import("cases/cast.zig"); _ = @import("cases/const_slice_child.zig"); _ = @import("cases/coroutines.zig"); + _ = @import("cases/coroutine_await_struct.zig"); _ = @import("cases/defer.zig"); _ = @import("cases/enum.zig"); _ = @import("cases/enum_with_members.zig"); diff --git a/test/cases/coroutine_await_struct.zig b/test/cases/coroutine_await_struct.zig new file mode 100644 index 0000000000..56c526092d --- /dev/null +++ b/test/cases/coroutine_await_struct.zig @@ -0,0 +1,47 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const assert = std.debug.assert; + +const Foo = struct { + x: i32, +}; + +var await_a_promise: promise = undefined; +var await_final_result = Foo{ .x = 0 }; + +test "coroutine await struct" { + var da = std.heap.DirectAllocator.init(); + defer da.deinit(); + + await_seq('a'); + const p = async<&da.allocator> await_amain() catch unreachable; + await_seq('f'); + resume await_a_promise; + await_seq('i'); + assert(await_final_result.x == 1234); + assert(std.mem.eql(u8, await_points, "abcdefghi")); +} +async fn await_amain() void { + await_seq('b'); + const p = async await_another() catch unreachable; + await_seq('e'); + await_final_result = await p; + await_seq('h'); +} +async fn await_another() Foo { + await_seq('c'); + suspend |p| { + await_seq('d'); + await_a_promise = p; + } + await_seq('g'); + return Foo{ .x = 1234 }; +} + +var await_points = []u8{0} ** "abcdefghi".len; +var await_seq_index: usize = 0; + +fn await_seq(c: u8) void { + await_points[await_seq_index] = c; + await_seq_index += 1; +} diff --git a/test/cases/coroutines.zig b/test/cases/coroutines.zig index b3899b306b..f7f2af62a6 100644 --- a/test/cases/coroutines.zig +++ b/test/cases/coroutines.zig @@ -116,14 +116,14 @@ test "coroutine await early return" { defer da.deinit(); early_seq('a'); - const p = async<&da.allocator> early_amain() catch unreachable; + const p = async<&da.allocator> early_amain() catch @panic("out of memory"); early_seq('f'); assert(early_final_result == 1234); assert(std.mem.eql(u8, early_points, "abcdef")); } async fn early_amain() void { early_seq('b'); - const p = async early_another() catch unreachable; + const p = async early_another() catch @panic("out of memory"); early_seq('d'); early_final_result = await p; early_seq('e');