diff --git a/BRANCH_TODO b/BRANCH_TODO index 7c19147aa8..4e128b78a1 100644 --- a/BRANCH_TODO +++ b/BRANCH_TODO @@ -1,5 +1,3 @@ - * fix @frameSize - * fix calling an inferred async function * await * await of a non async function * await in single-threaded mode diff --git a/src/all_types.hpp b/src/all_types.hpp index b5b8b06259..e66c9aebff 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2605,7 +2605,6 @@ struct IrInstructionCallGen { IrInstruction **args; IrInstruction *result_loc; IrInstruction *frame_result_loc; - IrBasicBlock *resume_block; IrInstruction *new_stack; FnInline fn_inline; diff --git a/src/analyze.cpp b/src/analyze.cpp index 5e22358423..99caf9688b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5185,13 +5185,6 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { if (!fn_is_async(callee)) continue; - IrBasicBlock *new_resume_block = allocate(1); - new_resume_block->name_hint = "CallResume"; - new_resume_block->split_llvm_fn = reinterpret_cast(0x1); - fn->resume_blocks.append(new_resume_block); - call->resume_block = new_resume_block; - fn->analyzed_executable.basic_block_list.append(new_resume_block); - ZigType *callee_frame_type = get_coro_frame_type(g, callee); IrInstructionAllocaGen *alloca_gen = allocate(1); diff --git a/src/codegen.cpp b/src/codegen.cpp index d955736083..d0aadaabe1 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3327,6 +3327,92 @@ static void set_call_instr_sret(CodeGen *g, LLVMValueRef call_instr) { LLVMAddCallSiteAttribute(call_instr, 1, sret_attr); } +static void render_async_spills(CodeGen *g) { + ZigType *fn_type = g->cur_fn->type_entry; + ZigType *import = get_scope_import(&g->cur_fn->fndef_scope->base); + size_t async_var_index = coro_arg_start + (type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 2 : 0); + for (size_t var_i = 0; var_i < g->cur_fn->variable_list.length; var_i += 1) { + ZigVar *var = g->cur_fn->variable_list.at(var_i); + + if (!type_has_bits(var->var_type)) { + continue; + } + if (ir_get_var_is_comptime(var)) + continue; + switch (type_requires_comptime(g, var->var_type)) { + case ReqCompTimeInvalid: + zig_unreachable(); + case ReqCompTimeYes: + continue; + case ReqCompTimeNo: + break; + } + if (var->src_arg_index == SIZE_MAX) { + continue; + } + + var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, + buf_ptr(&var->name)); + async_var_index += 1; + if (var->decl_node) { + var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope), + buf_ptr(&var->name), import->data.structure.root_struct->di_file, + (unsigned)(var->decl_node->line + 1), + get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0); + gen_var_debug_decl(g, var); + } + } + for (size_t alloca_i = 0; alloca_i < g->cur_fn->alloca_gen_list.length; alloca_i += 1) { + IrInstructionAllocaGen *instruction = g->cur_fn->alloca_gen_list.at(alloca_i); + ZigType *ptr_type = instruction->base.value.type; + assert(ptr_type->id == ZigTypeIdPointer); + ZigType *child_type = ptr_type->data.pointer.child_type; + if (!type_has_bits(child_type)) + continue; + if (instruction->base.ref_count == 0) + continue; + if (instruction->base.value.special != ConstValSpecialRuntime) { + if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special != + ConstValSpecialRuntime) + { + continue; + } + } + instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, + instruction->name_hint); + async_var_index += 1; + } +} + +static void render_async_var_decls(CodeGen *g, Scope *scope) { + render_async_spills(g); + for (;;) { + switch (scope->id) { + case ScopeIdCImport: + zig_unreachable(); + case ScopeIdFnDef: + return; + case ScopeIdVarDecl: { + ZigVar *var = reinterpret_cast(scope)->var; + if (var->ptr_instruction != nullptr) { + render_decl_var(g, var); + } + // fallthrough + } + case ScopeIdDecls: + case ScopeIdBlock: + case ScopeIdDefer: + case ScopeIdDeferExpr: + case ScopeIdLoop: + case ScopeIdSuspend: + case ScopeIdCompTime: + case ScopeIdRuntime: + scope = scope->parent; + continue; + } + } +} + static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCallGen *instruction) { LLVMValueRef fn_val; ZigType *fn_type; @@ -3431,15 +3517,19 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); return nullptr; } else if (callee_is_async) { + LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn); LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, ""); - LLVMValueRef new_fn_ptr = instruction->resume_block->split_llvm_fn; - LLVMBuildStore(g->builder, new_fn_ptr, fn_ptr_ptr); + LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr); LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); ZigLLVMSetTailCall(call_inst); LLVMBuildRetVoid(g->builder); - LLVMPositionBuilderAtEnd(g->builder, instruction->resume_block->llvm_block); + g->cur_fn_val = split_llvm_fn; + g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0); + LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume"); + LLVMPositionBuilderAtEnd(g->builder, call_bb); + render_async_var_decls(g, instruction->base.scope); return nullptr; } @@ -5193,92 +5283,6 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, zig_unreachable(); } -static void render_async_spills(CodeGen *g) { - ZigType *fn_type = g->cur_fn->type_entry; - ZigType *import = get_scope_import(&g->cur_fn->fndef_scope->base); - size_t async_var_index = coro_arg_start + (type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 2 : 0); - for (size_t var_i = 0; var_i < g->cur_fn->variable_list.length; var_i += 1) { - ZigVar *var = g->cur_fn->variable_list.at(var_i); - - if (!type_has_bits(var->var_type)) { - continue; - } - if (ir_get_var_is_comptime(var)) - continue; - switch (type_requires_comptime(g, var->var_type)) { - case ReqCompTimeInvalid: - zig_unreachable(); - case ReqCompTimeYes: - continue; - case ReqCompTimeNo: - break; - } - if (var->src_arg_index == SIZE_MAX) { - continue; - } - - var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, - buf_ptr(&var->name)); - async_var_index += 1; - if (var->decl_node) { - var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope), - buf_ptr(&var->name), import->data.structure.root_struct->di_file, - (unsigned)(var->decl_node->line + 1), - get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0); - gen_var_debug_decl(g, var); - } - } - for (size_t alloca_i = 0; alloca_i < g->cur_fn->alloca_gen_list.length; alloca_i += 1) { - IrInstructionAllocaGen *instruction = g->cur_fn->alloca_gen_list.at(alloca_i); - ZigType *ptr_type = instruction->base.value.type; - assert(ptr_type->id == ZigTypeIdPointer); - ZigType *child_type = ptr_type->data.pointer.child_type; - if (!type_has_bits(child_type)) - continue; - if (instruction->base.ref_count == 0) - continue; - if (instruction->base.value.special != ConstValSpecialRuntime) { - if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special != - ConstValSpecialRuntime) - { - continue; - } - } - instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, - instruction->name_hint); - async_var_index += 1; - } -} - -static void render_async_var_decls(CodeGen *g, Scope *scope) { - render_async_spills(g); - for (;;) { - switch (scope->id) { - case ScopeIdCImport: - zig_unreachable(); - case ScopeIdFnDef: - return; - case ScopeIdVarDecl: { - ZigVar *var = reinterpret_cast(scope)->var; - if (var->ptr_instruction != nullptr) { - render_decl_var(g, var); - } - // fallthrough - } - case ScopeIdDecls: - case ScopeIdBlock: - case ScopeIdDefer: - case ScopeIdDeferExpr: - case ScopeIdLoop: - case ScopeIdSuspend: - case ScopeIdCompTime: - case ScopeIdRuntime: - scope = scope->parent; - continue; - } - } -} - static void ir_render(CodeGen *g, ZigFn *fn_entry) { assert(fn_entry); diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 7a8a4a07df..fddc912e77 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -82,55 +82,55 @@ test "local variable in async function" { S.doTheTest(); } -//test "calling an inferred async function" { -// const S = struct { -// var x: i32 = 1; -// var other_frame: *@Frame(other) = undefined; -// -// fn doTheTest() void { -// const p = async first(); -// expect(x == 1); -// resume other_frame.*; -// expect(x == 2); -// } -// -// fn first() void { -// other(); -// } -// fn other() void { -// other_frame = @frame(); -// suspend; -// x += 1; -// } -// }; -// S.doTheTest(); -//} -// -//test "@frameSize" { -// const S = struct { -// fn doTheTest() void { -// { -// var ptr = @ptrCast(async fn(i32) void, other); -// const size = @frameSize(ptr); -// expect(size == @sizeOf(@Frame(other))); -// } -// { -// var ptr = @ptrCast(async fn() void, first); -// const size = @frameSize(ptr); -// expect(size == @sizeOf(@Frame(first))); -// } -// } -// -// fn first() void { -// other(1); -// } -// fn other(param: i32) void { -// var local: i32 = undefined; -// suspend; -// } -// }; -// S.doTheTest(); -//} +test "calling an inferred async function" { + const S = struct { + var x: i32 = 1; + var other_frame: *@Frame(other) = undefined; + + fn doTheTest() void { + const p = async first(); + expect(x == 1); + resume other_frame.*; + expect(x == 2); + } + + fn first() void { + other(); + } + fn other() void { + other_frame = @frame(); + suspend; + x += 1; + } + }; + S.doTheTest(); +} + +test "@frameSize" { + const S = struct { + fn doTheTest() void { + { + var ptr = @ptrCast(async fn(i32) void, other); + const size = @frameSize(ptr); + expect(size == @sizeOf(@Frame(other))); + } + { + var ptr = @ptrCast(async fn() void, first); + const size = @frameSize(ptr); + expect(size == @sizeOf(@Frame(first))); + } + } + + fn first() void { + other(1); + } + fn other(param: i32) void { + var local: i32 = undefined; + suspend; + } + }; + S.doTheTest(); +} //test "coroutine suspend, resume" { // seq('a');