From 9ca8d9e21ad657b023c23db5c440fb79a3303771 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 6 Sep 2019 16:17:39 -0400 Subject: [PATCH] fix await used in an expression generating bad LLVM --- src/analyze.cpp | 112 +++++++++++++++++++++--------- src/analyze.hpp | 4 ++ src/codegen.cpp | 25 ++++--- test/stage1/behavior/async_fn.zig | 16 +++++ 4 files changed, 113 insertions(+), 44 deletions(-) diff --git a/src/analyze.cpp b/src/analyze.cpp index fa93a9764c..c7da620428 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4232,31 +4232,40 @@ static Error analyze_callee_async(CodeGen *g, ZigFn *fn, ZigFn *callee, AstNode { if (modifier == CallModifierNoAsync) return ErrorNone; - if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified) - return ErrorNone; - if (callee->anal_state == FnAnalStateReady) { - analyze_fn_body(g, callee); - if (callee->anal_state == FnAnalStateInvalid) { - return ErrorSemanticAnalyzeFail; - } + bool callee_is_async = false; + switch (callee->type_entry->data.fn.fn_type_id.cc) { + case CallingConventionUnspecified: + break; + case CallingConventionAsync: + callee_is_async = true; + break; + default: + return ErrorNone; } - bool callee_is_async; - if (callee->anal_state == FnAnalStateComplete) { - analyze_fn_async(g, callee, true); - if (callee->anal_state == FnAnalStateInvalid) { - return ErrorSemanticAnalyzeFail; + if (!callee_is_async) { + if (callee->anal_state == FnAnalStateReady) { + analyze_fn_body(g, callee); + if (callee->anal_state == FnAnalStateInvalid) { + return ErrorSemanticAnalyzeFail; + } } - callee_is_async = fn_is_async(callee); - } else { - // If it's already been determined, use that value. Otherwise - // assume non-async, emit an error later if it turned out to be async. - if (callee->inferred_async_node == nullptr || - callee->inferred_async_node == inferred_async_checking) - { - callee->assumed_non_async = call_node; - callee_is_async = false; + if (callee->anal_state == FnAnalStateComplete) { + analyze_fn_async(g, callee, true); + if (callee->anal_state == FnAnalStateInvalid) { + return ErrorSemanticAnalyzeFail; + } + callee_is_async = fn_is_async(callee); } else { - callee_is_async = callee->inferred_async_node != inferred_async_none; + // If it's already been determined, use that value. Otherwise + // assume non-async, emit an error later if it turned out to be async. + if (callee->inferred_async_node == nullptr || + callee->inferred_async_node == inferred_async_checking) + { + callee->assumed_non_async = call_node; + callee_is_async = false; + } else { + callee_is_async = callee->inferred_async_node != inferred_async_none; + } } } if (callee_is_async) { @@ -4333,6 +4342,8 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) { } for (size_t i = 0; i < fn->await_list.length; i += 1) { IrInstructionAwaitGen *await = fn->await_list.at(i); + // TODO If this is a noasync await, it doesn't count + // https://github.com/ziglang/zig/issues/3157 switch (analyze_callee_async(g, fn, await->target_fn, await->base.source_node, must_not_be_async, CallModifierNone)) { @@ -5771,15 +5782,39 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { if (!fn_is_async(callee)) continue; - IrInstructionAllocaGen *alloca_gen = allocate(1); - alloca_gen->base.id = IrInstructionIdAllocaGen; - alloca_gen->base.source_node = call->base.source_node; - alloca_gen->base.scope = call->base.scope; - alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false); - alloca_gen->base.ref_count = 1; - alloca_gen->name_hint = ""; - fn->alloca_gen_list.append(alloca_gen); - call->frame_result_loc = &alloca_gen->base; + call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node, fn, + callee_frame_type, ""); + } + // Since this frame is async, an await might represent a suspend point, and + // therefore need to spill. + for (size_t i = 0; i < fn->await_list.length; i += 1) { + IrInstructionAwaitGen *await = fn->await_list.at(i); + // TODO If this is a noasync await, it doesn't need to spill + // https://github.com/ziglang/zig/issues/3157 + if (await->result_loc != nullptr) { + // If there's a result location, that is the spill + continue; + } + if (!type_has_bits(await->base.value.type)) + continue; + if (await->base.value.special != ConstValSpecialRuntime) + continue; + if (await->base.ref_count == 0) + continue; + if (await->target_fn != nullptr) { + // we might not need to suspend + analyze_fn_async(g, await->target_fn, false); + if (await->target_fn->anal_state == FnAnalStateInvalid) { + frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid; + return ErrorSemanticAnalyzeFail; + } + if (!fn_is_async(await->target_fn)) { + // This await does not represent a suspend point. No spill needed. + continue; + } + } + await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn, + await->base.value.type, ""); } FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false); @@ -8505,3 +8540,18 @@ void src_assert(bool ok, AstNode *source_node) { const char *msg = "assertion failed. This is a bug in the Zig compiler."; stage2_panic(msg, strlen(msg)); } + +IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn, + ZigType *var_type, const char *name_hint) +{ + IrInstructionAllocaGen *alloca_gen = allocate(1); + alloca_gen->base.id = IrInstructionIdAllocaGen; + alloca_gen->base.source_node = source_node; + alloca_gen->base.scope = scope; + alloca_gen->base.value.type = get_pointer_to_type(g, var_type, false); + alloca_gen->base.ref_count = 1; + alloca_gen->name_hint = name_hint; + fn->alloca_gen_list.append(alloca_gen); + return &alloca_gen->base; +} + diff --git a/src/analyze.hpp b/src/analyze.hpp index 9f2c984992..2178327571 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -258,4 +258,8 @@ ZigType *resolve_struct_field_type(CodeGen *g, TypeStructField *struct_field); void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn); +IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn, + ZigType *var_type, const char *name_hint); + + #endif diff --git a/src/codegen.cpp b/src/codegen.cpp index 6c03be32c3..bbb1d9fc87 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1661,6 +1661,14 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) { if (!type_has_bits(instruction->value.type)) return nullptr; if (!instruction->llvm_value) { + if (instruction->id == IrInstructionIdAwaitGen) { + IrInstructionAwaitGen *await = reinterpret_cast(instruction); + if (await->result_loc != nullptr) { + instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc), + await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type); + return instruction->llvm_value; + } + } src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node); assert(instruction->value.type); render_const_val(g, &instruction->value, ""); @@ -5645,7 +5653,6 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst // At this point resuming the function will continue from resume_bb. // This code is as if it is running inside the suspend block. - // supply the awaiter return pointer if (type_has_bits(result_type)) { LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start + 1, ""); @@ -5703,9 +5710,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst LLVMBuildBr(g->builder, end_bb); LLVMPositionBuilderAtEnd(g->builder, end_bb); - if (type_has_bits(result_type) && result_loc != nullptr) { - return get_handle_value(g, result_loc, result_type, ptr_result_type); - } + // Rely on the spill for the llvm_value to be populated. + // See the implementation of ir_llvm_value. return nullptr; } @@ -7153,15 +7159,8 @@ static void do_code_gen(CodeGen *g) { if (call->frame_result_loc != nullptr) continue; ZigType *callee_frame_type = get_fn_frame_type(g, call->fn_entry); - IrInstructionAllocaGen *alloca_gen = allocate(1); - alloca_gen->base.id = IrInstructionIdAllocaGen; - alloca_gen->base.source_node = call->base.source_node; - alloca_gen->base.scope = call->base.scope; - alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false); - alloca_gen->base.ref_count = 1; - alloca_gen->name_hint = ""; - fn_table_entry->alloca_gen_list.append(alloca_gen); - call->frame_result_loc = &alloca_gen->base; + call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node, + fn_table_entry, callee_frame_type, ""); } // allocate temporary stack data for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) { diff --git a/test/stage1/behavior/async_fn.zig b/test/stage1/behavior/async_fn.zig index a898889f5c..3079a7b98a 100644 --- a/test/stage1/behavior/async_fn.zig +++ b/test/stage1/behavior/async_fn.zig @@ -1108,3 +1108,19 @@ test "noasync function call" { }; S.doTheTest(); } + +test "await used in expression and awaiting fn with no suspend but async calling convention" { + const S = struct { + fn atest() void { + var f1 = async add(1, 2); + var f2 = async add(3, 4); + + const sum = (await f1) + (await f2); + expect(sum == 10); + } + async fn add(a: i32, b: i32) i32 { + return a + b; + } + }; + _ = async S.atest(); +}