diff --git a/src/all_types.hpp b/src/all_types.hpp index 1096feade0..cd64c149d9 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1726,6 +1726,7 @@ struct CodeGen { LLVMValueRef err_name_table; LLVMValueRef safety_crash_err_fn; LLVMValueRef return_err_fn; + LLVMTypeRef async_fn_llvm_type; // reminder: hash tables must be initialized before use HashMap import_table; @@ -1793,7 +1794,6 @@ struct CodeGen { ZigType *entry_global_error_set; ZigType *entry_arg_tuple; ZigType *entry_enum_literal; - ZigType *entry_frame_header; ZigType *entry_any_frame; } builtin_types; ZigType *align_amt_type; diff --git a/src/analyze.cpp b/src/analyze.cpp index e47be8f14c..c117409445 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -7348,19 +7348,13 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) { if (is_async) { fn_type->data.fn.gen_param_info = allocate(1); - ZigType *frame_type = g->builtin_types.entry_frame_header; - Error err; - if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) { - zig_unreachable(); - } - ZigType *ptr_type = get_pointer_to_type(g, frame_type, false); - gen_param_types.append(get_llvm_type(g, ptr_type)); - param_di_types.append(get_llvm_di_type(g, ptr_type)); + ZigType *frame_type = get_any_frame_type(g, fn_type_id->return_type); + gen_param_types.append(get_llvm_type(g, frame_type)); + param_di_types.append(get_llvm_di_type(g, frame_type)); fn_type->data.fn.gen_param_info[0].src_index = 0; fn_type->data.fn.gen_param_info[0].gen_index = 0; - fn_type->data.fn.gen_param_info[0].type = ptr_type; - + fn_type->data.fn.gen_param_info[0].type = frame_type; } else { fn_type->data.fn.gen_param_info = allocate(fn_type_id->param_count); for (size_t i = 0; i < fn_type_id->param_count; i += 1) { diff --git a/src/codegen.cpp b/src/codegen.cpp index c666317c17..0ee902b537 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4902,14 +4902,28 @@ static LLVMValueRef ir_render_suspend_br(CodeGen *g, IrExecutable *executable, return nullptr; } +static LLVMTypeRef async_fn_llvm_type(CodeGen *g) { + if (g->async_fn_llvm_type != nullptr) + return g->async_fn_llvm_type; + + ZigType *anyframe_type = get_any_frame_type(g, nullptr); + LLVMTypeRef param_type = get_llvm_type(g, anyframe_type); + LLVMTypeRef return_type = LLVMVoidType(); + LLVMTypeRef fn_type = LLVMFunctionType(return_type, ¶m_type, 1, false); + g->async_fn_llvm_type = LLVMPointerType(fn_type, 0); + + return g->async_fn_llvm_type; +} + static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable, IrInstructionCoroResume *instruction) { LLVMValueRef frame = ir_llvm_value(g, instruction->frame); ZigType *frame_type = instruction->frame->value.type; - assert(frame_type->id == ZigTypeIdCoroFrame); - ZigFn *fn = frame_type->data.frame.fn; - LLVMValueRef fn_val = fn_llvm_value(g, fn); + assert(frame_type->id == ZigTypeIdAnyFrame); + LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, ""); + LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); + LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, async_fn_llvm_type(g), ""); ZigLLVMBuildCall(g->builder, fn_val, &frame, 1, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); return nullptr; } @@ -6746,11 +6760,6 @@ static void define_builtin_types(CodeGen *g) { g->primitive_type_table.put(&entry->name, entry); } - { - const char *field_names[] = {"resume_index"}; - ZigType *field_types[] = {g->builtin_types.entry_usize}; - g->builtin_types.entry_frame_header = get_struct_type(g, "(frame header)", field_names, field_types, 1); - } } static BuiltinFnEntry *create_builtin_fn(CodeGen *g, BuiltinFnId id, const char *name, size_t count) { diff --git a/src/ir.cpp b/src/ir.cpp index e6d987a2ee..98a8f1061e 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -7764,7 +7764,7 @@ static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node) static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) { assert(node->type == NodeTypeResume); - IrInstruction *target_inst = ir_gen_node(irb, node->data.resume_expr.expr, scope); + IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.resume_expr.expr, scope, LValPtr, nullptr); if (target_inst == irb->codegen->invalid_instruction) return irb->codegen->invalid_instruction; @@ -10882,6 +10882,33 @@ static IrInstruction *ir_analyze_err_set_cast(IrAnalyze *ira, IrInstruction *sou return result; } +static IrInstruction *ir_analyze_frame_ptr_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *value, ZigType *wanted_type) +{ + if (instr_is_comptime(value)) { + zig_panic("TODO comptime frame pointer"); + } + + IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node, + wanted_type, value, CastOpBitCast); + result->value.type = wanted_type; + return result; +} + +static IrInstruction *ir_analyze_anyframe_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *value, ZigType *wanted_type) +{ + if (instr_is_comptime(value)) { + zig_panic("TODO comptime anyframe->T to anyframe"); + } + + IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node, + wanted_type, value, CastOpBitCast); + result->value.type = wanted_type; + return result; +} + + static IrInstruction *ir_analyze_err_wrap_code(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *value, ZigType *wanted_type, ResultLoc *result_loc) { @@ -11978,6 +12005,29 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst } } + // *@Frame(func) to anyframe->T or anyframe + if (actual_type->id == ZigTypeIdPointer && actual_type->data.pointer.ptr_len == PtrLenSingle && + actual_type->data.pointer.child_type->id == ZigTypeIdCoroFrame && wanted_type->id == ZigTypeIdAnyFrame) + { + bool ok = true; + if (wanted_type->data.any_frame.result_type != nullptr) { + ZigFn *fn = actual_type->data.pointer.child_type->data.frame.fn; + ZigType *fn_return_type = fn->type_entry->data.fn.fn_type_id.return_type; + if (wanted_type->data.any_frame.result_type != fn_return_type) { + ok = false; + } + } + if (ok) { + return ir_analyze_frame_ptr_to_anyframe(ira, source_instr, value, wanted_type); + } + } + + // anyframe->T to anyframe + if (actual_type->id == ZigTypeIdAnyFrame && actual_type->data.any_frame.result_type != nullptr && + wanted_type->id == ZigTypeIdAnyFrame && wanted_type->data.any_frame.result_type == nullptr) + { + return ir_analyze_anyframe_to_anyframe(ira, source_instr, value, wanted_type); + } // cast from null literal to maybe type if (wanted_type->id == ZigTypeIdOptional && @@ -24323,17 +24373,27 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru } static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstructionCoroResume *instruction) { - IrInstruction *frame = instruction->frame->child; - if (type_is_invalid(frame->value.type)) + IrInstruction *frame_ptr = instruction->frame->child; + if (type_is_invalid(frame_ptr->value.type)) return ira->codegen->invalid_instruction; - if (frame->value.type->id != ZigTypeIdCoroFrame) { - ir_add_error(ira, instruction->frame, - buf_sprintf("expected frame, found '%s'", buf_ptr(&frame->value.type->name))); - return ira->codegen->invalid_instruction; + IrInstruction *frame; + if (frame_ptr->value.type->id == ZigTypeIdPointer && + frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle && + frame_ptr->value.type->data.pointer.is_const && + frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdAnyFrame) + { + frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr); + } else { + frame = frame_ptr; } - return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, frame); + ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr); + IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type); + if (type_is_invalid(casted_frame->value.type)) + return ira->codegen->invalid_instruction; + + return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame); } static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) { diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 7af04d37c9..fddc912e77 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -5,15 +5,20 @@ const expect = std.testing.expect; var global_x: i32 = 1; test "simple coroutine suspend and resume" { - const p = async simpleAsyncFn(); + const frame = async simpleAsyncFn(); expect(global_x == 2); - resume p; + resume frame; expect(global_x == 3); + const af: anyframe->void = &frame; + resume frame; + expect(global_x == 4); } fn simpleAsyncFn() void { global_x += 1; suspend; global_x += 1; + suspend; + global_x += 1; } var global_y: i32 = 1;