diff --git a/doc/langref.html.in b/doc/langref.html.in index 25a90e3361..83d5e65bba 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -5682,7 +5682,7 @@ ErrorSetExpr = (PrefixOpExpression "!" PrefixOpExpression) | PrefixOpExpression BlockOrExpression = Block | Expression -Expression = TryExpression | ReturnExpression | BreakExpression | AssignmentExpression | CancelExpression +Expression = TryExpression | ReturnExpression | BreakExpression | AssignmentExpression | CancelExpression | ResumeExpression AsmExpression = "asm" option("volatile") "(" String option(AsmOutput) ")" @@ -5730,6 +5730,8 @@ BreakExpression = "break" option(":" Symbol) option(Expression) CancelExpression = "cancel" Expression; +ResumeExpression = "resume" Expression; + Defer(body) = ("defer" | "deferror") body IfExpression(body) = "if" "(" Expression ")" body option("else" BlockExpression(body)) diff --git a/src/all_types.hpp b/src/all_types.hpp index d727f4a862..d2c7875943 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -403,6 +403,7 @@ enum NodeType { NodeTypeTestExpr, NodeTypeErrorSetDecl, NodeTypeCancel, + NodeTypeResume, NodeTypeAwaitExpr, NodeTypeSuspend, }; @@ -849,6 +850,10 @@ struct AstNodeCancelExpr { AstNode *expr; }; +struct AstNodeResumeExpr { + AstNode *expr; +}; + struct AstNodeContinueExpr { Buf *name; }; @@ -930,6 +935,7 @@ struct AstNode { AstNodeVarLiteral var_literal; AstNodeErrorSetDecl err_set_decl; AstNodeCancelExpr cancel_expr; + AstNodeResumeExpr resume_expr; AstNodeAwaitExpr await_expr; AstNodeSuspend suspend; } data; diff --git a/src/analyze.cpp b/src/analyze.cpp index be01f6b5f8..8842b4967e 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3212,6 +3212,7 @@ void scan_decls(CodeGen *g, ScopeDecls *decls_scope, AstNode *node) { case NodeTypeTestExpr: case NodeTypeErrorSetDecl: case NodeTypeCancel: + case NodeTypeResume: case NodeTypeAwaitExpr: case NodeTypeSuspend: zig_unreachable(); diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 5f3e1998fd..6318ba3cff 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -246,6 +246,8 @@ static const char *node_type_str(NodeType node_type) { return "ErrorSetDecl"; case NodeTypeCancel: return "Cancel"; + case NodeTypeResume: + return "Resume"; case NodeTypeAwaitExpr: return "AwaitExpr"; case NodeTypeSuspend: @@ -1049,6 +1051,12 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) { render_node_grouped(ar, node->data.cancel_expr.expr); break; } + case NodeTypeResume: + { + fprintf(ar->f, "resume "); + render_node_grouped(ar, node->data.resume_expr.expr); + break; + } case NodeTypeAwaitExpr: { fprintf(ar->f, "await "); diff --git a/src/codegen.cpp b/src/codegen.cpp index 315699b826..59956c9279 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4004,7 +4004,7 @@ static LLVMValueRef ir_render_coro_save(CodeGen *g, IrExecutable *executable, Ir static LLVMValueRef get_coro_alloc_helper_fn_val(CodeGen *g, LLVMTypeRef alloc_fn_type_ref, TypeTableEntry *fn_type) { if (g->coro_alloc_helper_fn_val != nullptr) - return g->coro_alloc_fn_val; + return g->coro_alloc_helper_fn_val; assert(fn_type->id == TypeTableEntryIdFn); diff --git a/src/ir.cpp b/src/ir.cpp index dc845bdaf7..4222196f37 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -5927,6 +5927,16 @@ static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *parent_scope, AstNode return ir_build_cancel(irb, parent_scope, node, target_inst); } +static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *parent_scope, AstNode *node) { + assert(node->type == NodeTypeResume); + + IrInstruction *target_inst = ir_gen_node(irb, node->data.resume_expr.expr, parent_scope); + if (target_inst == irb->codegen->invalid_instruction) + return irb->codegen->invalid_instruction; + + return ir_build_coro_resume(irb, parent_scope, node, target_inst); +} + static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, AstNode *node) { assert(node->type == NodeTypeAwaitExpr); @@ -6101,6 +6111,8 @@ static IrInstruction *ir_gen_node_raw(IrBuilder *irb, AstNode *node, Scope *scop return ir_lval_wrap(irb, scope, ir_gen_err_set_decl(irb, scope, node), lval); case NodeTypeCancel: return ir_lval_wrap(irb, scope, ir_gen_cancel(irb, scope, node), lval); + case NodeTypeResume: + return ir_lval_wrap(irb, scope, ir_gen_resume(irb, scope, node), lval); case NodeTypeAwaitExpr: return ir_lval_wrap(irb, scope, ir_gen_await_expr(irb, scope, node), lval); case NodeTypeSuspend: @@ -17364,8 +17376,12 @@ static TypeTableEntry *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInst if (type_is_invalid(awaiter_handle->value.type)) return ira->codegen->builtin_types.entry_invalid; + IrInstruction *casted_target = ir_implicit_cast(ira, awaiter_handle, ira->codegen->builtin_types.entry_promise); + if (type_is_invalid(casted_target->value.type)) + return ira->codegen->builtin_types.entry_invalid; + IrInstruction *result = ir_build_coro_resume(&ira->new_irb, instruction->base.scope, - instruction->base.source_node, awaiter_handle); + instruction->base.source_node, casted_target); ir_link_new_instruction(result, &instruction->base); result->value.type = ira->codegen->builtin_types.entry_void; return result->value.type; diff --git a/src/parser.cpp b/src/parser.cpp index 763273fd0a..38994c79fc 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -1638,6 +1638,24 @@ static AstNode *ast_parse_cancel_expr(ParseContext *pc, size_t *token_index) { return node; } +/* +ResumeExpression = "resume" Expression; +*/ +static AstNode *ast_parse_resume_expr(ParseContext *pc, size_t *token_index) { + Token *token = &pc->tokens->at(*token_index); + + if (token->id != TokenIdKeywordResume) { + return nullptr; + } + *token_index += 1; + + AstNode *node = ast_create_node(pc, NodeTypeResume, token); + + node->data.resume_expr.expr = ast_parse_expression(pc, token_index, false); + + return node; +} + /* Defer(body) = ("defer" | "errdefer") body */ @@ -2266,7 +2284,7 @@ static AstNode *ast_parse_block_or_expression(ParseContext *pc, size_t *token_in } /* -Expression = TryExpression | ReturnExpression | BreakExpression | AssignmentExpression | CancelExpression +Expression = TryExpression | ReturnExpression | BreakExpression | AssignmentExpression | CancelExpression | ResumeExpression */ static AstNode *ast_parse_expression(ParseContext *pc, size_t *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); @@ -2287,6 +2305,10 @@ static AstNode *ast_parse_expression(ParseContext *pc, size_t *token_index, bool if (cancel_expr) return cancel_expr; + AstNode *resume_expr = ast_parse_resume_expr(pc, token_index); + if (resume_expr) + return resume_expr; + AstNode *ass_expr = ast_parse_ass_expr(pc, token_index, false); if (ass_expr) return ass_expr; @@ -3060,6 +3082,9 @@ void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *cont case NodeTypeCancel: visit_field(&node->data.cancel_expr.expr, visit, context); break; + case NodeTypeResume: + visit_field(&node->data.resume_expr.expr, visit, context); + break; case NodeTypeAwaitExpr: visit_field(&node->data.await_expr.expr, visit, context); break; diff --git a/test/cases/coroutines.zig b/test/cases/coroutines.zig index f5e70774fa..2a5505360c 100644 --- a/test/cases/coroutines.zig +++ b/test/cases/coroutines.zig @@ -14,3 +14,30 @@ async fn simpleAsyncFn() void { suspend; x += 1; } + +test "coroutine suspend, resume, cancel" { + seq('a'); + const p = (async(std.debug.global_allocator) testAsyncSeq()) catch unreachable; + seq('c'); + resume p; + seq('f'); + cancel p; + seq('g'); + + assert(std.mem.eql(u8, points, "abcdefg")); +} + +async fn testAsyncSeq() void { + defer seq('e'); + + seq('b'); + suspend; + seq('d'); +} +var points = []u8{0} ** "abcdefg".len; +var index: usize = 0; + +fn seq(c: u8) void { + points[index] = c; + index += 1; +}