diff --git a/src/all_types.hpp b/src/all_types.hpp index 503d45fd9b..72ec860556 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -58,8 +58,9 @@ struct IrExecutable { ZigList tld_list; IrInstruction *coro_handle; - IrInstruction *coro_awaiter_field_ptr; + IrInstruction *coro_awaiter_field_ptr; // this one is shared and in the promise IrInstruction *coro_result_ptr_field_ptr; + IrInstruction *await_handle_var_ptr; // this one is where we put the one we extracted from the promise IrBasicBlock *coro_early_final; IrBasicBlock *coro_normal_final; IrBasicBlock *coro_suspend_block; diff --git a/src/ir.cpp b/src/ir.cpp index 51c75ca23b..9a01b152d8 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -2752,6 +2752,7 @@ static IrInstruction *ir_gen_async_return(IrBuilder *irb, Scope *scope, AstNode IrInstruction *maybe_await_handle = ir_build_atomic_rmw(irb, scope, node, promise_type_val, irb->exec->coro_awaiter_field_ptr, nullptr, replacement_value, nullptr, AtomicRmwOp_xchg, AtomicOrderSeqCst); + ir_build_store_ptr(irb, scope, node, irb->exec->await_handle_var_ptr, maybe_await_handle); IrInstruction *is_non_null = ir_build_test_nonnull(irb, scope, node, maybe_await_handle); IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node, false); return ir_build_cond_br(irb, scope, node, is_non_null, irb->exec->coro_normal_final, irb->exec->coro_early_final, @@ -6020,7 +6021,6 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast ir_build_br(irb, parent_scope, node, merge_block, const_bool_false); ir_set_cursor_at_end_and_append_block(irb, yes_suspend_block); - ir_build_coro_resume(irb, parent_scope, node, target_inst); IrInstruction *suspend_code = ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false); IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup"); IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume"); @@ -6277,13 +6277,20 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec // create the coro promise const_bool_false = ir_build_const_bool(irb, scope, node, false); VariableTableEntry *promise_var = ir_create_var(irb, node, scope, nullptr, false, false, true, const_bool_false); - //scope = promise_var->child_scope; return_type = fn_entry->type_entry->data.fn.fn_type_id.return_type; IrInstruction *promise_init = ir_build_const_promise_init(irb, scope, node, return_type); ir_build_var_decl(irb, scope, node, promise_var, nullptr, nullptr, promise_init); IrInstruction *coro_promise_ptr = ir_build_var_ptr(irb, scope, node, promise_var, false, false); + VariableTableEntry *await_handle_var = ir_create_var(irb, node, scope, nullptr, false, false, true, const_bool_false); + IrInstruction *null_value = ir_build_const_null(irb, scope, node); + IrInstruction *await_handle_type_val = ir_build_const_type(irb, scope, node, + get_maybe_type(irb->codegen, irb->codegen->builtin_types.entry_promise)); + ir_build_var_decl(irb, scope, node, await_handle_var, await_handle_type_val, nullptr, null_value); + irb->exec->await_handle_var_ptr = ir_build_var_ptr(irb, scope, node, + await_handle_var, false, false); + u8_ptr_type = ir_build_const_type(irb, scope, node, get_pointer_to_type(irb->codegen, irb->codegen->builtin_types.entry_u8, false)); IrInstruction *promise_as_u8_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type, coro_promise_ptr); @@ -6409,7 +6416,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec ir_set_cursor_at_end_and_append_block(irb, resume_block); IrInstruction *unwrapped_await_handle_ptr = ir_build_unwrap_maybe(irb, scope, node, - irb->exec->coro_awaiter_field_ptr, false); + irb->exec->await_handle_var_ptr, false); IrInstruction *awaiter_handle = ir_build_load_ptr(irb, scope, node, unwrapped_await_handle_ptr); ir_build_coro_resume(irb, scope, node, awaiter_handle); ir_build_br(irb, scope, node, irb->exec->coro_suspend_block, const_bool_false); diff --git a/test/cases/coroutines.zig b/test/cases/coroutines.zig index 8f1909a64f..fa32cd8ce9 100644 --- a/test/cases/coroutines.zig +++ b/test/cases/coroutines.zig @@ -59,3 +59,42 @@ async fn testSuspendBlock() void { } result = true; } + +var await_a_promise: promise = undefined; +var await_final_result: i32 = 0; + +test "coroutine await" { + await_seq('a'); + const p = async(std.debug.global_allocator) await_amain() catch unreachable; + await_seq('f'); + resume await_a_promise; + await_seq('i'); + assert(await_final_result == 1234); + assert(std.mem.eql(u8, await_points, "abcdefghi")); +} + +async fn await_amain() void { + await_seq('b'); + const p = async await_another() catch unreachable; + await_seq('e'); + await_final_result = await p; + await_seq('h'); +} + +async fn await_another() i32 { + await_seq('c'); + suspend |p| { + await_seq('d'); + await_a_promise = p; + } + await_seq('g'); + return 1234; +} + +var await_points = []u8{0} ** "abcdefghi".len; +var await_seq_index: usize = 0; + +fn await_seq(c: u8) void { + await_points[await_seq_index] = c; + await_seq_index += 1; +}