diff --git a/src/all_types.hpp b/src/all_types.hpp index d2705d8ec6..63292dd8ec 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -63,6 +63,8 @@ struct IrExecutable { IrInstruction *implicit_allocator_ptr; IrBasicBlock *coro_early_final; IrBasicBlock *coro_normal_final; + IrBasicBlock *coro_suspend_block; + IrBasicBlock *coro_final_cleanup_block; }; enum OutType { @@ -1631,6 +1633,7 @@ struct CodeGen { LLVMValueRef coro_end_fn_val; LLVMValueRef coro_free_fn_val; LLVMValueRef coro_resume_fn_val; + LLVMValueRef coro_save_fn_val; bool error_during_imports; const char **clang_argv; @@ -2000,6 +2003,7 @@ enum IrInstructionId { IrInstructionIdCoroEnd, IrInstructionIdCoroFree, IrInstructionIdCoroResume, + IrInstructionIdCoroSave, }; struct IrInstruction { @@ -2902,6 +2906,12 @@ struct IrInstructionCoroResume { IrInstruction *awaiter_handle; }; +struct IrInstructionCoroSave { + IrInstruction base; + + IrInstruction *coro_handle; +}; + static const size_t slice_ptr_index = 0; static const size_t slice_len_index = 1; diff --git a/src/codegen.cpp b/src/codegen.cpp index 0c4f66daa4..f82c686b85 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1066,6 +1066,21 @@ static LLVMValueRef get_coro_resume_fn_val(CodeGen *g) { return g->coro_resume_fn_val; } +static LLVMValueRef get_coro_save_fn_val(CodeGen *g) { + if (g->coro_save_fn_val) + return g->coro_save_fn_val; + + LLVMTypeRef param_types[] = { + LLVMPointerType(LLVMInt8Type(), 0), + }; + LLVMTypeRef fn_type = LLVMFunctionType(ZigLLVMTokenTypeInContext(LLVMGetGlobalContext()), param_types, 1, false); + Buf *name = buf_sprintf("llvm.coro.save"); + g->coro_save_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); + assert(LLVMGetIntrinsicID(g->coro_save_fn_val)); + + return g->coro_save_fn_val; +} + static LLVMValueRef get_return_address_fn_val(CodeGen *g) { if (g->return_address_fn_val) return g->return_address_fn_val; @@ -3954,6 +3969,11 @@ static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable, return LLVMBuildCall(g->builder, get_coro_resume_fn_val(g), &awaiter_handle, 1, ""); } +static LLVMValueRef ir_render_coro_save(CodeGen *g, IrExecutable *executable, IrInstructionCoroSave *instruction) { + LLVMValueRef coro_handle = ir_llvm_value(g, instruction->coro_handle); + return LLVMBuildCall(g->builder, get_coro_save_fn_val(g), &coro_handle, 1, ""); +} + static void set_debug_location(CodeGen *g, IrInstruction *instruction) { AstNode *source_node = instruction->source_node; Scope *scope = instruction->scope; @@ -4157,6 +4177,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_coro_free(g, executable, (IrInstructionCoroFree *)instruction); case IrInstructionIdCoroResume: return ir_render_coro_resume(g, executable, (IrInstructionCoroResume *)instruction); + case IrInstructionIdCoroSave: + return ir_render_coro_save(g, executable, (IrInstructionCoroSave *)instruction); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 7ed66b92bd..2600f5e948 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -691,6 +691,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionCoroResume *) { return IrInstructionIdCoroResume; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionCoroSave *) { + return IrInstructionIdCoroSave; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -2585,6 +2589,17 @@ static IrInstruction *ir_build_coro_resume(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } +static IrInstruction *ir_build_coro_save(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *coro_handle) +{ + IrInstructionCoroSave *instruction = ir_build_instruction(irb, scope, source_node); + instruction->coro_handle = coro_handle; + + ir_ref_instruction(coro_handle, irb->current_basic_block); + + return &instruction->base; +} + static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) { results[ReturnKindUnconditional] = 0; results[ReturnKindError] = 0; @@ -5847,7 +5862,67 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNode *node) { assert(node->type == NodeTypeSuspend); - zig_panic("TODO: generate suspend"); + FnTableEntry *fn_entry = exec_fn_entry(irb->exec); + if (!fn_entry) { + add_node_error(irb->codegen, node, buf_sprintf("suspend outside function definition")); + return irb->codegen->invalid_instruction; + } + if (fn_entry->type_entry->data.fn.fn_type_id.cc != CallingConventionAsync) { + add_node_error(irb->codegen, node, buf_sprintf("suspend in non-async function")); + return irb->codegen->invalid_instruction; + } + + ScopeDeferExpr *scope_defer_expr = get_scope_defer_expr(parent_scope); + if (scope_defer_expr) { + if (!scope_defer_expr->reported_err) { + add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression")); + scope_defer_expr->reported_err = true; + } + return irb->codegen->invalid_instruction; + } + + Scope *outer_scope = irb->exec->begin_scope; + + + IrInstruction *suspend_code; + IrInstruction *const_bool_false = ir_build_const_bool(irb, parent_scope, node, false); + if (node->data.suspend.block == nullptr) { + suspend_code = ir_build_coro_suspend(irb, parent_scope, node, nullptr, const_bool_false); + } else { + assert(node->data.suspend.promise_symbol != nullptr); + assert(node->data.suspend.promise_symbol->type == NodeTypeSymbol); + Buf *promise_symbol_name = node->data.suspend.promise_symbol->data.symbol_expr.symbol; + Scope *child_scope; + if (!buf_eql_str(promise_symbol_name, "_")) { + VariableTableEntry *promise_var = ir_create_var(irb, node, parent_scope, promise_symbol_name, + true, true, false, const_bool_false); + ir_build_var_decl(irb, parent_scope, node, promise_var, nullptr, nullptr, irb->exec->coro_handle); + child_scope = promise_var->child_scope; + } else { + child_scope = parent_scope; + } + IrInstruction *save_token = ir_build_coro_save(irb, child_scope, node, irb->exec->coro_handle); + ir_gen_node(irb, node->data.suspend.block, child_scope); + suspend_code = ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false); + } + + IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup"); + IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume"); + + IrInstructionSwitchBrCase *cases = allocate(2); + cases[0].value = ir_build_const_u8(irb, parent_scope, node, 0); + cases[0].block = resume_block; + cases[1].value = ir_build_const_u8(irb, parent_scope, node, 1); + cases[1].block = cleanup_block; + ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block, + 2, cases, const_bool_false); + + ir_set_cursor_at_end_and_append_block(irb, cleanup_block); + ir_gen_defers_for_block(irb, parent_scope, outer_scope, true); + ir_build_br(irb, parent_scope, node, irb->exec->coro_final_cleanup_block, const_bool_false); + + ir_set_cursor_at_end_and_append_block(irb, resume_block); + return ir_build_const_void(irb, parent_scope, node); } static IrInstruction *ir_gen_node_raw(IrBuilder *irb, AstNode *node, Scope *scope, @@ -6099,6 +6174,8 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec irb->exec->coro_early_final = ir_create_basic_block(irb, scope, "CoroEarlyFinal"); irb->exec->coro_normal_final = ir_create_basic_block(irb, scope, "CoroNormalFinal"); + irb->exec->coro_suspend_block = ir_create_basic_block(irb, scope, "Suspend"); + irb->exec->coro_final_cleanup_block = ir_create_basic_block(irb, scope, "FinalCleanup"); } IrInstruction *result = ir_gen_node_extra(irb, node, scope, LVAL_NONE); @@ -6112,8 +6189,6 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec if (is_async) { IrBasicBlock *invalid_resume_block = ir_create_basic_block(irb, scope, "InvalidResume"); - IrBasicBlock *final_cleanup_block = ir_create_basic_block(irb, scope, "FinalCleanup"); - IrBasicBlock *suspend_block = ir_create_basic_block(irb, scope, "Suspend"); IrBasicBlock *check_free_block = ir_create_basic_block(irb, scope, "CheckFree"); ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_early_final); @@ -6123,10 +6198,10 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec cases[0].value = ir_build_const_u8(irb, scope, node, 0); cases[0].block = invalid_resume_block; cases[1].value = ir_build_const_u8(irb, scope, node, 1); - cases[1].block = final_cleanup_block; - ir_build_switch_br(irb, scope, node, suspend_code, suspend_block, 2, cases, const_bool_false); + cases[1].block = irb->exec->coro_final_cleanup_block; + ir_build_switch_br(irb, scope, node, suspend_code, irb->exec->coro_suspend_block, 2, cases, const_bool_false); - ir_set_cursor_at_end_and_append_block(irb, suspend_block); + ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_suspend_block); ir_build_coro_end(irb, scope, node); ir_build_return(irb, scope, node, irb->exec->coro_handle); @@ -6136,7 +6211,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_normal_final); ir_build_br(irb, scope, node, check_free_block, const_bool_false); - ir_set_cursor_at_end_and_append_block(irb, final_cleanup_block); + ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_final_cleanup_block); if (type_has_bits(return_type)) { IrInstruction *result_ptr = ir_build_load_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr); IrInstruction *result_ptr_as_u8_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type, result_ptr); @@ -6152,7 +6227,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec ir_set_cursor_at_end_and_append_block(irb, check_free_block); IrBasicBlock **incoming_blocks = allocate(2); IrInstruction **incoming_values = allocate(2); - incoming_blocks[0] = final_cleanup_block; + incoming_blocks[0] = irb->exec->coro_final_cleanup_block; incoming_values[0] = const_bool_false; incoming_blocks[1] = irb->exec->coro_normal_final; incoming_values[1] = const_bool_true; @@ -17219,6 +17294,18 @@ static TypeTableEntry *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInst return result->value.type; } +static TypeTableEntry *ir_analyze_instruction_coro_save(IrAnalyze *ira, IrInstructionCoroSave *instruction) { + IrInstruction *coro_handle = instruction->coro_handle->other; + if (type_is_invalid(coro_handle->value.type)) + return ira->codegen->builtin_types.entry_invalid; + + IrInstruction *result = ir_build_coro_save(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, coro_handle); + ir_link_new_instruction(result, &instruction->base); + result->value.type = ira->codegen->builtin_types.entry_usize; + return result->value.type; +} + static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) { switch (instruction->id) { @@ -17444,6 +17531,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi return ir_analyze_instruction_coro_free(ira, (IrInstructionCoroFree *)instruction); case IrInstructionIdCoroResume: return ir_analyze_instruction_coro_resume(ira, (IrInstructionCoroResume *)instruction); + case IrInstructionIdCoroSave: + return ir_analyze_instruction_coro_save(ira, (IrInstructionCoroSave *)instruction); } zig_unreachable(); } @@ -17566,6 +17655,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCoroAllocFail: case IrInstructionIdCoroEnd: case IrInstructionIdCoroResume: + case IrInstructionIdCoroSave: return true; case IrInstructionIdPhi: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index ca7eb25879..2e367672a5 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1090,6 +1090,12 @@ static void ir_print_coro_resume(IrPrint *irp, IrInstructionCoroResume *instruct fprintf(irp->f, ")"); } +static void ir_print_coro_save(IrPrint *irp, IrInstructionCoroSave *instruction) { + fprintf(irp->f, "@coroSave("); + ir_print_other_instruction(irp, instruction->coro_handle); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -1443,6 +1449,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdCoroResume: ir_print_coro_resume(irp, (IrInstructionCoroResume *)instruction); break; + case IrInstructionIdCoroSave: + ir_print_coro_save(irp, (IrInstructionCoroSave *)instruction); + break; } fprintf(irp->f, "\n"); }