diff --git a/doc/langref.html.in b/doc/langref.html.in index 7fde550338..54677bc5b5 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -4690,9 +4690,9 @@ test "coroutine suspend with block" { var a_promise: promise = undefined; var result = false; async fn testSuspendBlock() void { - suspend |p| { - comptime assert(@typeOf(p) == promise->void); - a_promise = p; + suspend { + comptime assert(@typeOf(@handle()) == promise->void); + a_promise = @handle(); } result = true; } @@ -4733,8 +4733,8 @@ test "resume from suspend" { std.debug.assert(my_result == 2); } async fn testResumeFromSuspend(my_result: *i32) void { - suspend |p| { - resume p; + suspend { + resume @handle(); } my_result.* += 1; suspend; @@ -4791,9 +4791,9 @@ async fn amain() void { } async fn another() i32 { seq('c'); - suspend |p| { + suspend { seq('d'); - a_promise = p; + a_promise = @handle(); } seq('g'); return 1234; @@ -5383,6 +5383,16 @@ test "main" { This function is only valid within function scope.

{#header_close#} + {#header_open|@handle#} +
@handle()
+

+ This function returns a promise->T type, where T + is the return type of the async function in scope. +

+

+ This function is only valid within an async function scope. +

+ {#header_close#} {#header_open|@import#}
@import(comptime path: []u8) (namespace)

@@ -7388,7 +7398,7 @@ Defer(body) = ("defer" | "deferror") body IfExpression(body) = "if" "(" Expression ")" body option("else" BlockExpression(body)) -SuspendExpression(body) = "suspend" option(("|" Symbol "|" body)) +SuspendExpression(body) = "suspend" option( body ) IfErrorExpression(body) = "if" "(" Expression ")" option("|" option("*") Symbol "|") body "else" "|" Symbol "|" BlockExpression(body) diff --git a/src/all_types.hpp b/src/all_types.hpp index 2f09e70301..b1e8a3746d 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -899,7 +899,6 @@ struct AstNodeAwaitExpr { struct AstNodeSuspend { AstNode *block; - AstNode *promise_symbol; }; struct AstNodePromiseType { @@ -1358,6 +1357,7 @@ enum BuiltinFnId { BuiltinFnIdBreakpoint, BuiltinFnIdReturnAddress, BuiltinFnIdFrameAddress, + BuiltinFnIdHandle, BuiltinFnIdEmbedFile, BuiltinFnIdCmpxchgWeak, BuiltinFnIdCmpxchgStrong, @@ -1716,6 +1716,7 @@ struct CodeGen { LLVMValueRef coro_save_fn_val; LLVMValueRef coro_promise_fn_val; LLVMValueRef coro_alloc_helper_fn_val; + LLVMValueRef coro_frame_fn_val; LLVMValueRef merge_err_ret_traces_fn_val; LLVMValueRef add_error_return_trace_addr_fn_val; LLVMValueRef stacksave_fn_val; @@ -2076,6 +2077,7 @@ enum IrInstructionId { IrInstructionIdBreakpoint, IrInstructionIdReturnAddress, IrInstructionIdFrameAddress, + IrInstructionIdHandle, IrInstructionIdAlignOf, IrInstructionIdOverflowOp, IrInstructionIdTestErr, @@ -2793,6 +2795,10 @@ struct IrInstructionFrameAddress { IrInstruction base; }; +struct IrInstructionHandle { + IrInstruction base; +}; + enum IrOverflowOp { IrOverflowOpAdd, IrOverflowOpSub, diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 2ace00885d..984b4230b1 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -1112,9 +1112,6 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) { { fprintf(ar->f, "suspend"); if (node->data.suspend.block != nullptr) { - fprintf(ar->f, " |"); - render_node_grouped(ar, node->data.suspend.promise_symbol); - fprintf(ar->f, "| "); render_node_grouped(ar, node->data.suspend.block); } break; diff --git a/src/codegen.cpp b/src/codegen.cpp index 7420da9797..539356ef2f 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4146,6 +4146,26 @@ static LLVMValueRef ir_render_frame_address(CodeGen *g, IrExecutable *executable return LLVMBuildCall(g->builder, get_frame_address_fn_val(g), &zero, 1, ""); } +static LLVMValueRef get_handle_fn_val(CodeGen *g) { + if (g->coro_frame_fn_val) + return g->coro_frame_fn_val; + + LLVMTypeRef fn_type = LLVMFunctionType( LLVMPointerType(LLVMInt8Type(), 0) + , nullptr, 0, false); + Buf *name = buf_sprintf("llvm.coro.frame"); + g->coro_frame_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); + assert(LLVMGetIntrinsicID(g->coro_frame_fn_val)); + + return g->coro_frame_fn_val; +} + +static LLVMValueRef ir_render_handle(CodeGen *g, IrExecutable *executable, + IrInstructionHandle *instruction) +{ + LLVMValueRef zero = LLVMConstNull(g->builtin_types.entry_promise->type_ref); + return LLVMBuildCall(g->builder, get_handle_fn_val(g), &zero, 0, ""); +} + static LLVMValueRef render_shl_with_overflow(CodeGen *g, IrInstructionOverflowOp *instruction) { TypeTableEntry *int_type = instruction->result_ptr_type; assert(int_type->id == TypeTableEntryIdInt); @@ -4910,6 +4930,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_return_address(g, executable, (IrInstructionReturnAddress *)instruction); case IrInstructionIdFrameAddress: return ir_render_frame_address(g, executable, (IrInstructionFrameAddress *)instruction); + case IrInstructionIdHandle: + return ir_render_handle(g, executable, (IrInstructionHandle *)instruction); case IrInstructionIdOverflowOp: return ir_render_overflow_op(g, executable, (IrInstructionOverflowOp *)instruction); case IrInstructionIdTestErr: @@ -6005,6 +6027,7 @@ static void do_code_gen(CodeGen *g) { ir_render(g, fn_table_entry); } + assert(!g->errors.length); if (buf_len(&g->global_asm) != 0) { @@ -6344,6 +6367,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdBreakpoint, "breakpoint", 0); create_builtin_fn(g, BuiltinFnIdReturnAddress, "returnAddress", 0); create_builtin_fn(g, BuiltinFnIdFrameAddress, "frameAddress", 0); + create_builtin_fn(g, BuiltinFnIdHandle, "handle", 0); create_builtin_fn(g, BuiltinFnIdMemcpy, "memcpy", 3); create_builtin_fn(g, BuiltinFnIdMemset, "memset", 3); create_builtin_fn(g, BuiltinFnIdSizeof, "sizeOf", 1); diff --git a/src/ir.cpp b/src/ir.cpp index 699baa152e..7d2881744d 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -580,6 +580,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionFrameAddress *) return IrInstructionIdFrameAddress; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionHandle *) { + return IrInstructionIdHandle; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionAlignOf *) { return IrInstructionIdAlignOf; } @@ -2240,6 +2244,17 @@ static IrInstruction *ir_build_frame_address_from(IrBuilder *irb, IrInstruction return new_instruction; } +static IrInstruction *ir_build_handle(IrBuilder *irb, Scope *scope, AstNode *source_node) { + IrInstructionHandle *instruction = ir_build_instruction(irb, scope, source_node); + return &instruction->base; +} + +static IrInstruction *ir_build_handle_from(IrBuilder *irb, IrInstruction *old_instruction) { + IrInstruction *new_instruction = ir_build_handle(irb, old_instruction->scope, old_instruction->source_node); + ir_link_new_instruction(new_instruction, old_instruction); + return new_instruction; +} + static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode *source_node, IrOverflowOp op, IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *result_ptr, TypeTableEntry *result_ptr_type) @@ -3843,6 +3858,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo return irb->codegen->invalid_instruction; } + bool is_async = exec_is_async(irb->exec); + switch (builtin_fn->id) { case BuiltinFnIdInvalid: zig_unreachable(); @@ -4475,6 +4492,16 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo return ir_lval_wrap(irb, scope, ir_build_return_address(irb, scope, node), lval); case BuiltinFnIdFrameAddress: return ir_lval_wrap(irb, scope, ir_build_frame_address(irb, scope, node), lval); + case BuiltinFnIdHandle: + if (!irb->exec->fn_entry) { + add_node_error(irb->codegen, node, buf_sprintf("@handle() called outside of function definition")); + return irb->codegen->invalid_instruction; + } + if (!is_async) { + add_node_error(irb->codegen, node, buf_sprintf("@handle() in non-async function")); + return irb->codegen->invalid_instruction; + } + return ir_lval_wrap(irb, scope, ir_build_handle(irb, scope, node), lval); case BuiltinFnIdAlignOf: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); @@ -7069,19 +7096,8 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod if (node->data.suspend.block == nullptr) { suspend_code = ir_build_coro_suspend(irb, parent_scope, node, nullptr, const_bool_false); } else { - assert(node->data.suspend.promise_symbol != nullptr); - assert(node->data.suspend.promise_symbol->type == NodeTypeSymbol); - Buf *promise_symbol_name = node->data.suspend.promise_symbol->data.symbol_expr.symbol; Scope *child_scope; - if (!buf_eql_str(promise_symbol_name, "_")) { - VariableTableEntry *promise_var = ir_create_var(irb, node, parent_scope, promise_symbol_name, - true, true, false, const_bool_false); - ir_build_var_decl(irb, parent_scope, node, promise_var, nullptr, nullptr, irb->exec->coro_handle); - child_scope = promise_var->child_scope; - } else { - child_scope = parent_scope; - } - ScopeSuspend *suspend_scope = create_suspend_scope(node, child_scope); + ScopeSuspend *suspend_scope = create_suspend_scope(node, parent_scope); suspend_scope->resume_block = resume_block; child_scope = &suspend_scope->base; IrInstruction *save_token = ir_build_coro_save(irb, child_scope, node, irb->exec->coro_handle); @@ -19007,6 +19023,14 @@ static TypeTableEntry *ir_analyze_instruction_frame_address(IrAnalyze *ira, IrIn return u8_ptr_const; } +static TypeTableEntry *ir_analyze_instruction_handle(IrAnalyze *ira, IrInstructionHandle *instruction) { + ir_build_handle_from(&ira->new_irb, &instruction->base); + + FnTableEntry *fn_entry = exec_fn_entry(ira->new_irb.exec); + assert(fn_entry != nullptr); + return get_promise_type(ira->codegen, fn_entry->type_entry->data.fn.fn_type_id.return_type); +} + static TypeTableEntry *ir_analyze_instruction_align_of(IrAnalyze *ira, IrInstructionAlignOf *instruction) { IrInstruction *type_value = instruction->type_value->other; if (type_is_invalid(type_value->value.type)) @@ -20982,6 +21006,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi return ir_analyze_instruction_return_address(ira, (IrInstructionReturnAddress *)instruction); case IrInstructionIdFrameAddress: return ir_analyze_instruction_frame_address(ira, (IrInstructionFrameAddress *)instruction); + case IrInstructionIdHandle: + return ir_analyze_instruction_handle(ira, (IrInstructionHandle *)instruction); case IrInstructionIdAlignOf: return ir_analyze_instruction_align_of(ira, (IrInstructionAlignOf *)instruction); case IrInstructionIdOverflowOp: @@ -21274,6 +21300,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdAlignOf: case IrInstructionIdReturnAddress: case IrInstructionIdFrameAddress: + case IrInstructionIdHandle: case IrInstructionIdTestErr: case IrInstructionIdUnwrapErrCode: case IrInstructionIdOptionalWrap: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 127afa94a5..77c7ef47b6 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -791,6 +791,10 @@ static void ir_print_frame_address(IrPrint *irp, IrInstructionFrameAddress *inst fprintf(irp->f, "@frameAddress()"); } +static void ir_print_handle(IrPrint *irp, IrInstructionHandle *instruction) { + fprintf(irp->f, "@handle()"); +} + static void ir_print_return_address(IrPrint *irp, IrInstructionReturnAddress *instruction) { fprintf(irp->f, "@returnAddress()"); } @@ -1556,6 +1560,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdFrameAddress: ir_print_frame_address(irp, (IrInstructionFrameAddress *)instruction); break; + case IrInstructionIdHandle: + ir_print_handle(irp, (IrInstructionHandle *)instruction); + break; case IrInstructionIdAlignOf: ir_print_align_of(irp, (IrInstructionAlignOf *)instruction); break; diff --git a/src/parser.cpp b/src/parser.cpp index a93d8de830..84ccdbeea8 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -648,35 +648,37 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, size_t *token_index, bool m } /* -SuspendExpression(body) = "suspend" option(("|" Symbol "|" body)) +SuspendExpression(body) = "suspend" option( body ) */ static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, bool mandatory) { size_t orig_token_index = *token_index; + Token *token = &pc->tokens->at(*token_index); + Token *suspend_token = nullptr; - Token *suspend_token = &pc->tokens->at(*token_index); - if (suspend_token->id == TokenIdKeywordSuspend) { + if (token->id == TokenIdKeywordSuspend) { *token_index += 1; + suspend_token = token; + token = &pc->tokens->at(*token_index); } else if (mandatory) { - ast_expect_token(pc, suspend_token, TokenIdKeywordSuspend); + ast_expect_token(pc, token, TokenIdKeywordSuspend); zig_unreachable(); } else { return nullptr; } - Token *bar_token = &pc->tokens->at(*token_index); - if (bar_token->id == TokenIdBinOr) { - *token_index += 1; - } else if (mandatory) { - ast_expect_token(pc, suspend_token, TokenIdBinOr); - zig_unreachable(); - } else { - *token_index = orig_token_index; - return nullptr; + //guessing that semicolon is checked elsewhere? + if (token->id != TokenIdLBrace) { + if (mandatory) { + ast_expect_token(pc, token, TokenIdLBrace); + zig_unreachable(); + } else { + *token_index = orig_token_index; + return nullptr; + } } + //Expect that we have a block; AstNode *node = ast_create_node(pc, NodeTypeSuspend, suspend_token); - node->data.suspend.promise_symbol = ast_parse_symbol(pc, token_index); - ast_eat_token(pc, token_index, TokenIdBinOr); node->data.suspend.block = ast_parse_block(pc, token_index, true); return node; @@ -3134,7 +3136,6 @@ void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *cont visit_field(&node->data.await_expr.expr, visit, context); break; case NodeTypeSuspend: - visit_field(&node->data.suspend.promise_symbol, visit, context); visit_field(&node->data.suspend.block, visit, context); break; } diff --git a/std/event/channel.zig b/std/event/channel.zig index 03a036042b..71e97f6e78 100644 --- a/std/event/channel.zig +++ b/std/event/channel.zig @@ -71,10 +71,10 @@ pub fn Channel(comptime T: type) type { /// puts a data item in the channel. The promise completes when the value has been added to the /// buffer, or in the case of a zero size buffer, when the item has been retrieved by a getter. pub async fn put(self: *SelfChannel, data: T) void { - suspend |handle| { + suspend { var my_tick_node = Loop.NextTickNode{ .next = undefined, - .data = handle, + .data = @handle(), }; var queue_node = std.atomic.Queue(PutNode).Node{ .data = PutNode{ @@ -96,10 +96,10 @@ pub fn Channel(comptime T: type) type { // TODO integrate this function with named return values // so we can get rid of this extra result copy var result: T = undefined; - suspend |handle| { + suspend { var my_tick_node = Loop.NextTickNode{ .next = undefined, - .data = handle, + .data = @handle(), }; var queue_node = std.atomic.Queue(GetNode).Node{ .data = GetNode{ diff --git a/std/event/future.zig b/std/event/future.zig index f5d14d1ca6..8abdce7d02 100644 --- a/std/event/future.zig +++ b/std/event/future.zig @@ -100,8 +100,8 @@ test "std.event.Future" { } async fn testFuture(loop: *Loop) void { - suspend |p| { - resume p; + suspend { + resume @handle(); } var future = Future(i32).init(loop); @@ -115,15 +115,15 @@ async fn testFuture(loop: *Loop) void { } async fn waitOnFuture(future: *Future(i32)) i32 { - suspend |p| { - resume p; + suspend { + resume @handle(); } return (await (async future.get() catch @panic("memory"))).*; } async fn resolveFuture(future: *Future(i32)) void { - suspend |p| { - resume p; + suspend { + resume @handle(); } future.data = 6; future.resolve(); diff --git a/std/event/group.zig b/std/event/group.zig index 26c098399e..6c7fc63699 100644 --- a/std/event/group.zig +++ b/std/event/group.zig @@ -54,10 +54,10 @@ pub fn Group(comptime ReturnType: type) type { const S = struct { async fn asyncFunc(node: **Stack.Node, args2: ...) ReturnType { // TODO this is a hack to make the memory following be inside the coro frame - suspend |p| { + suspend { var my_node: Stack.Node = undefined; node.* = &my_node; - resume p; + resume @handle(); } // TODO this allocation elision should be guaranteed because we await it in diff --git a/std/event/lock.zig b/std/event/lock.zig index 0bd7183db2..c4cb1a3f0e 100644 --- a/std/event/lock.zig +++ b/std/event/lock.zig @@ -90,10 +90,10 @@ pub const Lock = struct { } pub async fn acquire(self: *Lock) Held { - suspend |handle| { + suspend { // TODO explicitly put this memory in the coroutine frame #1194 var my_tick_node = Loop.NextTickNode{ - .data = handle, + .data = @handle(), .next = undefined, }; @@ -141,8 +141,8 @@ test "std.event.Lock" { async fn testLock(loop: *Loop, lock: *Lock) void { // TODO explicitly put next tick node memory in the coroutine frame #1194 - suspend |p| { - resume p; + suspend { + resume @handle(); } const handle1 = async lockRunner(lock) catch @panic("out of memory"); var tick_node1 = Loop.NextTickNode{ diff --git a/std/event/loop.zig b/std/event/loop.zig index 4e219653be..8b1b2e53db 100644 --- a/std/event/loop.zig +++ b/std/event/loop.zig @@ -331,11 +331,11 @@ pub const Loop = struct { pub async fn waitFd(self: *Loop, fd: i32) !void { defer self.removeFd(fd); - suspend |p| { + suspend { // TODO explicitly put this memory in the coroutine frame #1194 var resume_node = ResumeNode{ .id = ResumeNode.Id.Basic, - .handle = p, + .handle = @handle(), }; try self.addFd(fd, &resume_node); } @@ -417,11 +417,11 @@ pub const Loop = struct { pub fn call(self: *Loop, comptime func: var, args: ...) !(promise->@typeOf(func).ReturnType) { const S = struct { async fn asyncFunc(loop: *Loop, handle: *promise->@typeOf(func).ReturnType, args2: ...) @typeOf(func).ReturnType { - suspend |p| { - handle.* = p; + suspend { + handle.* = @handle(); var my_tick_node = Loop.NextTickNode{ .next = undefined, - .data = p, + .data = @handle(), }; loop.onNextTick(&my_tick_node); } @@ -439,10 +439,10 @@ pub const Loop = struct { /// CPU bound tasks would be waiting in the event loop but never get started because no async I/O /// is performed. pub async fn yield(self: *Loop) void { - suspend |p| { + suspend { var my_tick_node = Loop.NextTickNode{ .next = undefined, - .data = p, + .data = @handle(), }; self.onNextTick(&my_tick_node); } diff --git a/std/event/tcp.zig b/std/event/tcp.zig index 416a8c07dc..ea803a9322 100644 --- a/std/event/tcp.zig +++ b/std/event/tcp.zig @@ -88,8 +88,8 @@ pub const Server = struct { }, error.ProcessFdQuotaExceeded => { errdefer std.os.emfile_promise_queue.remove(&self.waiting_for_emfile_node); - suspend |p| { - self.waiting_for_emfile_node = PromiseNode.init(p); + suspend { + self.waiting_for_emfile_node = PromiseNode.init( @handle() ); std.os.emfile_promise_queue.append(&self.waiting_for_emfile_node); } continue; @@ -141,8 +141,8 @@ test "listen on a port, send bytes, receive bytes" { (await next_handler) catch |err| { std.debug.panic("unable to handle connection: {}\n", err); }; - suspend |p| { - cancel p; + suspend { + cancel @handle(); } } async fn errorableHandler(self: *Self, _addr: *const std.net.Address, _socket: *const std.os.File) !void { diff --git a/std/zig/parser_test.zig b/std/zig/parser_test.zig index 21259bec3c..32cdc8121f 100644 --- a/std/zig/parser_test.zig +++ b/std/zig/parser_test.zig @@ -1784,7 +1784,7 @@ test "zig fmt: coroutines" { \\ x += 1; \\ suspend; \\ x += 1; - \\ suspend |p| {} + \\ suspend; \\ const p: promise->void = async simpleAsyncFn() catch unreachable; \\ await p; \\} diff --git a/test/cases/cancel.zig b/test/cases/cancel.zig index edf11d687d..c0f74fd34f 100644 --- a/test/cases/cancel.zig +++ b/test/cases/cancel.zig @@ -85,8 +85,8 @@ async fn b4() void { defer { defer_b4 = true; } - suspend |p| { - b4_handle = p; + suspend { + b4_handle = @handle(); } suspend; } diff --git a/test/cases/coroutine_await_struct.zig b/test/cases/coroutine_await_struct.zig index 56c526092d..79168715d8 100644 --- a/test/cases/coroutine_await_struct.zig +++ b/test/cases/coroutine_await_struct.zig @@ -30,9 +30,9 @@ async fn await_amain() void { } async fn await_another() Foo { await_seq('c'); - suspend |p| { + suspend { await_seq('d'); - await_a_promise = p; + await_a_promise = @handle(); } await_seq('g'); return Foo{ .x = 1234 }; diff --git a/test/cases/coroutines.zig b/test/cases/coroutines.zig index 72a4ed0b38..a955eeac37 100644 --- a/test/cases/coroutines.zig +++ b/test/cases/coroutines.zig @@ -62,10 +62,15 @@ test "coroutine suspend with block" { var a_promise: promise = undefined; var result = false; async fn testSuspendBlock() void { - suspend |p| { - comptime assert(@typeOf(p) == promise->void); - a_promise = p; + suspend { + comptime assert(@typeOf(@handle()) == promise->void); + a_promise = @handle(); } + + //Test to make sure that @handle() works as advertised (issue #1296) + //var our_handle: promise = @handle(); + assert( a_promise == @handle() ); + result = true; } @@ -93,9 +98,9 @@ async fn await_amain() void { } async fn await_another() i32 { await_seq('c'); - suspend |p| { + suspend { await_seq('d'); - await_a_promise = p; + await_a_promise = @handle(); } await_seq('g'); return 1234; @@ -244,10 +249,26 @@ test "break from suspend" { std.debug.assert(my_result == 2); } async fn testBreakFromSuspend(my_result: *i32) void { - suspend |p| { - resume p; + suspend { + resume @handle(); } my_result.* += 1; suspend; my_result.* += 1; } + +test "suspend resume @handle()" { + var buf: [500]u8 = undefined; + var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator; + var my_result: i32 = 1; + const p = try async testBreakFromSuspend(&my_result); + std.debug.assert(my_result == 2); +} +async fn testSuspendResumeAtHandle() void { + suspend { + resume @handle(); + } + my_result.* += 1; + suspend; + my_result.* += 1; +} \ No newline at end of file diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 2c4c9208eb..f4b289f70d 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -367,8 +367,8 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\} \\ \\async fn foo() void { - \\ suspend |p| { - \\ suspend |p1| { + \\ suspend { + \\ suspend { \\ } \\ } \\} @@ -4738,4 +4738,34 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { , ".tmp_source.zig:3:36: error: @ArgType could not resolve the type of arg 0 because 'fn(var)var' is generic", ); + + cases.add( + "@handle() called outside of function definition", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\ + \\var handle_undef: promise = undefined; + \\var handle_dummy: promise = @handle(); + \\ + \\pub fn main() void { + \\ if (handle_undef == handle_dummy) return 0; + \\} + , + ".tmp_source.zig:6:29: error: @handle() called outside of function definition", + ); + + cases.add( + "@handle() in non-async function", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\ + \\pub fn main() void { + \\ var handle_undef: promise = undefined; + \\ if (handle_undef == @handle()) return 0; + \\} + , + ".tmp_source.zig:7:25: error: @handle() in non-async function", + ); }