diff --git a/BRANCH_TODO b/BRANCH_TODO index f3d881f5e5..a9bc5f3666 100644 --- a/BRANCH_TODO +++ b/BRANCH_TODO @@ -6,7 +6,6 @@ * @asyncCall with an async function pointer * cancel * defer and errdefer - * safety for resuming when it is awaiting * safety for double await * implicit cast of normal function to async function should be allowed when it is inferred to be async * go over the commented out tests @@ -19,3 +18,6 @@ * make sure there are safety tests for all the new safety features (search the new PanicFnId enum values) * error return tracing * compile error for casting a function to a non-async function pointer, but then later it gets inferred to be an async function + * compile error for copying a frame + * compile error for resuming a const frame pointer + * runtime safety enabling/disabling scope has to be coordinated across resume/await/calls/return diff --git a/src/all_types.hpp b/src/all_types.hpp index 079b8ded95..0f8cce1376 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1552,6 +1552,7 @@ enum PanicMsgId { PanicMsgIdBadResume, PanicMsgIdBadAwait, PanicMsgIdBadReturn, + PanicMsgIdResumedAnAwaitingFn, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index b050e02e0a..db617e636a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -877,6 +877,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("async function awaited twice"); case PanicMsgIdBadReturn: return buf_create_from_str("async function returned twice"); + case PanicMsgIdResumedAnAwaitingFn: + return buf_create_from_str("awaiting function resumed"); } zig_unreachable(); } @@ -2018,7 +2020,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns } result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, ""); } else { - result_ptr_as_usize = LLVMGetUndef(usize_type_ref); + // For debug safety, this value has to be anything other than all 1's, which signals + // that it is being resumed. 0 is a bad choice since null pointers are special. + result_ptr_as_usize = ir_want_runtime_safety(g, &return_instruction->base) ? + LLVMConstInt(usize_type_ref, 1, false) : LLVMGetUndef(usize_type_ref); } LLVMValueRef zero = LLVMConstNull(usize_type_ref); LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); @@ -3582,8 +3587,9 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr); } } + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; if (instruction->is_async) { - LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)}; + LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)}; ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, ""); return nullptr; } else if (callee_is_async) { @@ -3591,8 +3597,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn); LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, ""); LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr); - - LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)}; + LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)}; LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, ""); ZigLLVMSetTailCall(call_inst); LLVMBuildRetVoid(g->builder); @@ -3601,6 +3606,21 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0); LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume"); LLVMPositionBuilderAtEnd(g->builder, call_bb); + + if (ir_want_runtime_safety(g, &instruction->base)) { + LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume"); + LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume"); + LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, ""); + LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block); + + LLVMPositionBuilderAtEnd(g->builder, bad_resume_block); + gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn); + + LLVMPositionBuilderAtEnd(g->builder, ok_resume_block); + } + render_async_var_decls(g, instruction->base.scope); if (type_has_bits(src_return_type)) { @@ -5139,6 +5159,21 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0); LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "AwaitResume"); LLVMPositionBuilderAtEnd(g->builder, call_bb); + + if (ir_want_runtime_safety(g, &instruction->base)) { + LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume"); + LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume"); + LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, ""); + LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block); + + LLVMPositionBuilderAtEnd(g->builder, bad_resume_block); + gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn); + + LLVMPositionBuilderAtEnd(g->builder, ok_resume_block); + } + render_async_var_decls(g, instruction->base.scope); if (type_has_bits(result_type)) { @@ -5178,7 +5213,9 @@ static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable, LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, ""); LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, anyframe_fn_type(g), ""); - LLVMValueRef args[] = {frame, LLVMGetUndef(usize_type_ref)}; + LLVMValueRef arg_val = ir_want_runtime_safety(g, &instruction->base) ? + LLVMConstAllOnes(usize_type_ref) : LLVMGetUndef(usize_type_ref); + LLVMValueRef args[] = {frame, arg_val}; ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); return nullptr; } diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 336dbb8bf0..43cf0856c3 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,38 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety("resuming a function which is awaiting a frame", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var frame = async first(); + \\ resume frame; + \\} + \\fn first() void { + \\ var frame = async other(); + \\ await frame; + \\} + \\fn other() void { + \\ suspend; + \\} + ); + cases.addRuntimeSafety("resuming a function which is awaiting a call", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var frame = async first(); + \\ resume frame; + \\} + \\fn first() void { + \\ other(); + \\} + \\fn other() void { + \\ suspend; + \\} + ); + cases.addRuntimeSafety("invalid resume of async function", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126); diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index aa77541d19..2b82dce707 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -89,7 +89,7 @@ test "calling an inferred async function" { var other_frame: *@Frame(other) = undefined; fn doTheTest() void { - const p = async first(); + _ = async first(); expect(x == 1); resume other_frame.*; expect(x == 2);