From fcadeb50c04199fce2d9675ba2976680c71c67ff Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 22 Jul 2019 14:36:14 -0400 Subject: [PATCH] fix multiple coroutines existing clobbering each other --- src/analyze.cpp | 68 +++++++++++++++++++++-------- src/analyze.hpp | 2 +- src/codegen.cpp | 66 +++++++++++++++++----------- test/runtime_safety.zig | 14 ++++++ test/stage1/behavior/coroutines.zig | 18 ++++++++ 5 files changed, 123 insertions(+), 45 deletions(-) diff --git a/src/analyze.cpp b/src/analyze.cpp index aff11e017f..4bb3de095e 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1865,7 +1865,8 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) { } static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { - assert(frame_type->data.frame.locals_struct == nullptr); + if (frame_type->data.frame.locals_struct != nullptr) + return ErrorNone; ZigFn *fn = frame_type->data.frame.fn; switch (fn->anal_state) { @@ -3824,6 +3825,15 @@ static void analyze_fn_ir(CodeGen *g, ZigFn *fn_table_entry, AstNode *return_typ } fn_table_entry->anal_state = FnAnalStateComplete; + + if (fn_table_entry->resume_blocks.length != 0) { + ZigType *frame_type = get_coro_frame_type(g, fn_table_entry); + Error err; + if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) { + fn_table_entry->anal_state = FnAnalStateInvalid; + return; + } + } } static void analyze_fn_body(CodeGen *g, ZigFn *fn_table_entry) { @@ -7050,18 +7060,12 @@ static void resolve_llvm_types_array(CodeGen *g, ZigType *type) { debug_align_in_bits, get_llvm_di_type(g, elem_type), (int)type->data.array.len); } -void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { - if (fn_type->llvm_di_type != nullptr) { - if (fn != nullptr) { - fn->raw_type_ref = fn_type->data.fn.raw_type_ref; - fn->raw_di_type = fn_type->data.fn.raw_di_type; - } - return; - } +static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) { + if (fn_type->llvm_di_type != nullptr) return; FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; bool first_arg_return = want_first_arg_sret(g, fn_type_id); - bool is_async = fn_type_id->cc == CallingConventionAsync || (fn != nullptr && fn->resume_blocks.length != 0); + bool is_async = fn_type_id->cc == CallingConventionAsync; bool is_c_abi = fn_type_id->cc == CallingConventionC; bool prefix_arg_error_return_trace = g->have_err_ret_tracing && fn_type_can_fail(fn_type_id); // +1 for maybe making the first argument the return value @@ -7100,7 +7104,11 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { if (is_async) { fn_type->data.fn.gen_param_info = allocate(1); - ZigType *frame_type = (fn == nullptr) ? g->builtin_types.entry_frame_header : get_coro_frame_type(g, fn); + ZigType *frame_type = g->builtin_types.entry_frame_header; + Error err; + if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) { + zig_unreachable(); + } ZigType *ptr_type = get_pointer_to_type(g, frame_type, false); gen_param_types.append(get_llvm_type(g, ptr_type)); param_di_types.append(get_llvm_di_type(g, ptr_type)); @@ -7150,12 +7158,7 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { for (size_t i = 0; i < gen_param_types.length; i += 1) { assert(gen_param_types.items[i] != nullptr); } - if (fn != nullptr) { - fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), - gen_param_types.items, (unsigned int)gen_param_types.length, fn_type_id->is_var_args); - fn->raw_di_type = ZigLLVMCreateSubroutineType(g->dbuilder, param_di_types.items, (int)param_di_types.length, 0); - return; - } + fn_type->data.fn.raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), gen_param_types.items, (unsigned int)gen_param_types.length, fn_type_id->is_var_args); fn_type->llvm_type = LLVMPointerType(fn_type->data.fn.raw_type_ref, 0); @@ -7165,6 +7168,35 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { LLVMABIAlignmentOfType(g->target_data_ref, fn_type->llvm_type), ""); } +void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn) { + if (fn->raw_di_type != nullptr) return; + + ZigType *fn_type = fn->type_entry; + FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; + bool cc_async = fn_type_id->cc == CallingConventionAsync; + bool inferred_async = fn->resume_blocks.length != 0; + bool is_async = cc_async || inferred_async; + if (!is_async) { + resolve_llvm_types_fn_type(g, fn_type); + fn->raw_type_ref = fn_type->data.fn.raw_type_ref; + fn->raw_di_type = fn_type->data.fn.raw_di_type; + return; + } + + ZigType *gen_return_type = g->builtin_types.entry_usize; + ZigList param_di_types = {}; + // first "parameter" is return value + param_di_types.append(get_llvm_di_type(g, gen_return_type)); + + ZigType *frame_type = get_coro_frame_type(g, fn); + ZigType *ptr_type = get_pointer_to_type(g, frame_type, false); + LLVMTypeRef gen_param_type = get_llvm_type(g, ptr_type); + param_di_types.append(get_llvm_di_type(g, ptr_type)); + + fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), &gen_param_type, 1, false); + fn->raw_di_type = ZigLLVMCreateSubroutineType(g->dbuilder, param_di_types.items, (int)param_di_types.length, 0); +} + static void resolve_llvm_types_anyerror(CodeGen *g) { ZigType *entry = g->builtin_types.entry_global_error_set; entry->llvm_type = get_llvm_type(g, g->err_tag_type); @@ -7241,7 +7273,7 @@ static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_r case ZigTypeIdArray: return resolve_llvm_types_array(g, type); case ZigTypeIdFn: - return resolve_llvm_types_fn(g, type, nullptr); + return resolve_llvm_types_fn_type(g, type); case ZigTypeIdErrorSet: { if (type->llvm_di_type != nullptr) return; diff --git a/src/analyze.hpp b/src/analyze.hpp index 3f226080b5..57f4452104 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -247,6 +247,6 @@ void src_assert(bool ok, AstNode *source_node); bool is_container(ZigType *type_entry); ConstExprValue *analyze_const_value(CodeGen *g, Scope *scope, AstNode *node, ZigType *type_entry, Buf *type_name); -void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn); +void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn); #endif diff --git a/src/codegen.cpp b/src/codegen.cpp index 339b04cc90..f3519ea72d 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -371,10 +371,12 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) { symbol_name = buf_sprintf("\x01_%s", buf_ptr(symbol_name)); } + bool is_async = fn_table_entry->resume_blocks.length != 0 || cc == CallingConventionAsync; + ZigType *fn_type = fn_table_entry->type_entry; // Make the raw_type_ref populated - resolve_llvm_types_fn(g, fn_type, fn_table_entry); + resolve_llvm_types_fn(g, fn_table_entry); LLVMTypeRef fn_llvm_type = fn_table_entry->raw_type_ref; if (fn_table_entry->body_node == nullptr) { LLVMValueRef existing_llvm_fn = LLVMGetNamedFunction(g->module, buf_ptr(symbol_name)); @@ -397,7 +399,7 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) { assert(entry->value->id == TldIdFn); TldFn *tld_fn = reinterpret_cast(entry->value); // Make the raw_type_ref populated - resolve_llvm_types_fn(g, tld_fn->fn_entry->type_entry, tld_fn->fn_entry); + resolve_llvm_types_fn(g, tld_fn->fn_entry); tld_fn->fn_entry->llvm_value = LLVMAddFunction(g->module, buf_ptr(symbol_name), tld_fn->fn_entry->raw_type_ref); fn_table_entry->llvm_value = LLVMConstBitCast(tld_fn->fn_entry->llvm_value, @@ -517,18 +519,22 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) { init_gen_i = 1; } - // set parameter attributes - FnWalk fn_walk = {}; - fn_walk.id = FnWalkIdAttrs; - fn_walk.data.attrs.fn = fn_table_entry; - fn_walk.data.attrs.gen_i = init_gen_i; - walk_function_params(g, fn_type, &fn_walk); + if (is_async) { + addLLVMArgAttr(fn_table_entry->llvm_value, 0, "nonnull"); + } else { + // set parameter attributes + FnWalk fn_walk = {}; + fn_walk.id = FnWalkIdAttrs; + fn_walk.data.attrs.fn = fn_table_entry; + fn_walk.data.attrs.gen_i = init_gen_i; + walk_function_params(g, fn_type, &fn_walk); - uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry); - if (err_ret_trace_arg_index != UINT32_MAX) { - // Error return trace memory is in the stack, which is impossible to be at address 0 - // on any architecture. - addLLVMArgAttr(fn_table_entry->llvm_value, (unsigned)err_ret_trace_arg_index, "nonnull"); + uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry); + if (err_ret_trace_arg_index != UINT32_MAX) { + // Error return trace memory is in the stack, which is impossible to be at address 0 + // on any architecture. + addLLVMArgAttr(fn_table_entry->llvm_value, (unsigned)err_ret_trace_arg_index, "nonnull"); + } } return fn_table_entry->llvm_value; @@ -6254,14 +6260,21 @@ static void do_code_gen(CodeGen *g) { } else if (is_c_abi) { fn_walk_var.data.vars.var = var; iter_function_params_c_abi(g, fn_table_entry->type_entry, &fn_walk_var, var->src_arg_index); + } else if (is_async) { + var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, ""); + if (var->decl_node) { + var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope), + buf_ptr(&var->name), import->data.structure.root_struct->di_file, + (unsigned)(var->decl_node->line + 1), + get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0); + gen_var_debug_decl(g, var); + } } else { ZigType *gen_type; FnGenParamInfo *gen_info = &fn_table_entry->type_entry->data.fn.gen_param_info[var->src_arg_index]; assert(gen_info->gen_index != SIZE_MAX); - if (is_async) { - var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, ""); - } else if (handle_is_ptr(var->var_type)) { + if (handle_is_ptr(var->var_type)) { if (gen_info->is_byval) { gen_type = var->var_type; } else { @@ -6307,16 +6320,7 @@ static void do_code_gen(CodeGen *g) { gen_store(g, LLVMConstInt(usize->llvm_type, stack_trace_ptr_count, false), len_field_ptr, get_pointer_to_type(g, usize, false)); } - // create debug variable declarations for parameters - // rely on the first variables in the variable_list being parameters. - FnWalk fn_walk_init = {}; - fn_walk_init.id = FnWalkIdInits; - fn_walk_init.data.inits.fn = fn_table_entry; - fn_walk_init.data.inits.llvm_fn = fn; - fn_walk_init.data.inits.gen_i = gen_i_init; - walk_function_params(g, fn_table_entry->type_entry, &fn_walk_init); - - if (fn_table_entry->resume_blocks.length != 0) { + if (is_async) { if (!g->strip_debug_symbols) { AstNode *source_node = fn_table_entry->proto_node; ZigLLVMSetCurrentDebugLocation(g->builder, (int)source_node->line + 1, @@ -6354,8 +6358,18 @@ static void do_code_gen(CodeGen *g) { LLVMValueRef case_value = LLVMConstInt(usize_type_ref, resume_i + 2, false); LLVMAddCase(switch_instr, case_value, fn_table_entry->resume_blocks.at(resume_i)->llvm_block); } + } else { + // create debug variable declarations for parameters + // rely on the first variables in the variable_list being parameters. + FnWalk fn_walk_init = {}; + fn_walk_init.id = FnWalkIdInits; + fn_walk_init.data.inits.fn = fn_table_entry; + fn_walk_init.data.inits.llvm_fn = fn; + fn_walk_init.data.inits.gen_i = gen_i_init; + walk_function_params(g, fn_table_entry->type_entry, &fn_walk_init); } + ir_render(g, fn_table_entry); } diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index aee607e10f..336dbb8bf0 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,20 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety("invalid resume of async function", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var p = async suspendOnce(); + \\ resume p; //ok + \\ resume p; //bad + \\} + \\fn suspendOnce() void { + \\ suspend; + \\} + ); + cases.addRuntimeSafety(".? operator on null pointer", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126); diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 3a54657020..4ecd4efd13 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -29,6 +29,24 @@ fn simpleAsyncFnWithArg(delta: i32) void { suspend; global_y += delta; } + +test "suspend at end of function" { + const S = struct { + var x: i32 = 1; + + fn doTheTest() void { + expect(x == 1); + const p = async suspendAtEnd(); + expect(x == 2); + } + + fn suspendAtEnd() void { + x += 1; + suspend; + } + }; + S.doTheTest(); +} //test "coroutine suspend, resume" { // seq('a'); // const p = try async testAsyncSeq();