diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 16536dc71f..71a233c964 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -7641,12 +7641,7 @@ static IrInstSrc *ir_gen_fn_call(IrBuilderSrc *irb, Scope *scope, AstNode *node, bool is_nosuspend = get_scope_nosuspend(scope) != nullptr; CallModifier modifier = node->data.fn_call_expr.modifier; - if (is_nosuspend) { - if (modifier == CallModifierAsync) { - add_node_error(irb->codegen, node, - buf_sprintf("async call in nosuspend scope")); - return irb->codegen->invalid_inst_src; - } + if (is_nosuspend && modifier != CallModifierAsync) { modifier = CallModifierNoSuspend; } @@ -10129,10 +10124,6 @@ static IrInstSrc *ir_gen_fn_proto(IrBuilderSrc *irb, Scope *parent_scope, AstNod static IrInstSrc *ir_gen_resume(IrBuilderSrc *irb, Scope *scope, AstNode *node) { assert(node->type == NodeTypeResume); - if (get_scope_nosuspend(scope) != nullptr) { - add_node_error(irb->codegen, node, buf_sprintf("resume in nosuspend scope")); - return irb->codegen->invalid_inst_src; - } IrInstSrc *target_inst = ir_gen_node_extra(irb, node->data.resume_expr.expr, scope, LValPtr, nullptr); if (target_inst == irb->codegen->invalid_inst_src) diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 2fb4c36ed4..808f0b31dc 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -1027,9 +1027,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\} \\fn foo() void {} , &[_][]const u8{ - "tmp.zig:3:21: error: async call in nosuspend scope", "tmp.zig:4:9: error: suspend in nosuspend scope", - "tmp.zig:5:9: error: resume in nosuspend scope", }); cases.add("atomicrmw with bool op not .Xchg", diff --git a/test/stage1/behavior/async_fn.zig b/test/stage1/behavior/async_fn.zig index 16c7b14944..40269df5ec 100644 --- a/test/stage1/behavior/async_fn.zig +++ b/test/stage1/behavior/async_fn.zig @@ -1,5 +1,5 @@ const std = @import("std"); -const builtin = @import("builtin"); +const builtin = std.builtin; const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; const expectEqualStrings = std.testing.expectEqualStrings; @@ -1545,6 +1545,68 @@ test "nosuspend on function calls" { expectEqual(@as(i32, 42), (try nosuspend S1.d()).b); } +test "nosuspend on async function calls" { + const S0 = struct { + b: i32 = 42, + }; + const S1 = struct { + fn c() S0 { + return S0{}; + } + fn d() !S0 { + return S0{}; + } + }; + var frame_c = nosuspend async S1.c(); + expectEqual(@as(i32, 42), (await frame_c).b); + var frame_d = nosuspend async S1.d(); + expectEqual(@as(i32, 42), (try await frame_d).b); +} + +// test "resume nosuspend async function calls" { +// const S0 = struct { +// b: i32 = 42, +// }; +// const S1 = struct { +// fn c() S0 { +// suspend; +// return S0{}; +// } +// fn d() !S0 { +// suspend; +// return S0{}; +// } +// }; +// var frame_c = nosuspend async S1.c(); +// resume frame_c; +// expectEqual(@as(i32, 42), (await frame_c).b); +// var frame_d = nosuspend async S1.d(); +// resume frame_d; +// expectEqual(@as(i32, 42), (try await frame_d).b); +// } + +test "nosuspend resume async function calls" { + const S0 = struct { + b: i32 = 42, + }; + const S1 = struct { + fn c() S0 { + suspend; + return S0{}; + } + fn d() !S0 { + suspend; + return S0{}; + } + }; + var frame_c = async S1.c(); + nosuspend resume frame_c; + expectEqual(@as(i32, 42), (await frame_c).b); + var frame_d = async S1.d(); + nosuspend resume frame_d; + expectEqual(@as(i32, 42), (try await frame_d).b); +} + test "avoid forcing frame alignment resolution implicit cast to *c_void" { const S = struct { var x: ?*c_void = null;