diff --git a/src/all_types.hpp b/src/all_types.hpp index 7ea15dc763..052c67d334 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1715,6 +1715,7 @@ enum PanicMsgId { PanicMsgIdFrameTooSmall, PanicMsgIdResumedFnPendingAwait, PanicMsgIdBadNoAsyncCall, + PanicMsgIdResumeNotSuspendedFn, PanicMsgIdCount, }; @@ -1886,6 +1887,7 @@ struct CodeGen { size_t cur_resume_block_count; LLVMValueRef cur_err_ret_trace_val_arg; LLVMValueRef cur_err_ret_trace_val_stack; + LLVMValueRef cur_bad_not_suspended_index; LLVMValueRef memcpy_fn_val; LLVMValueRef memset_fn_val; LLVMValueRef trap_fn_val; diff --git a/src/codegen.cpp b/src/codegen.cpp index 37715716e1..1b14c5551b 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -933,6 +933,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("resumed an async function which can only be awaited"); case PanicMsgIdBadNoAsyncCall: return buf_create_from_str("async function called with noasync suspended"); + case PanicMsgIdResumeNotSuspendedFn: + return buf_create_from_str("resumed a non-suspended function"); } zig_unreachable(); } @@ -2234,6 +2236,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume LLVMBasicBlockRef end_bb) { LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; + + if (ir_want_runtime_safety(g, source_instr)) { + // Write a value to the resume index which indicates the function was resumed while not suspended. + LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr); + } + LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadResume"); if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume"); LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref), @@ -5764,6 +5772,9 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl LLVMBuildRetVoid(g->builder); LLVMPositionBuilderAtEnd(g->builder, instruction->begin->resume_bb); + if (ir_want_runtime_safety(g, &instruction->base)) { + LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr); + } render_async_var_decls(g, instruction->base.scope); return nullptr; } @@ -7542,7 +7553,20 @@ static void do_code_gen(CodeGen *g) { IrBasicBlock *entry_block = executable->basic_block_list.at(0); LLVMAddCase(switch_instr, zero, entry_block->llvm_block); g->cur_resume_block_count += 1; + + { + LLVMBasicBlockRef bad_not_suspended_bb = LLVMAppendBasicBlock(g->cur_fn_val, "NotSuspended"); + size_t new_block_index = g->cur_resume_block_count; + g->cur_resume_block_count += 1; + g->cur_bad_not_suspended_index = LLVMConstInt(usize_type_ref, new_block_index, false); + LLVMAddCase(g->cur_async_switch_instr, g->cur_bad_not_suspended_index, bad_not_suspended_bb); + + LLVMPositionBuilderAtEnd(g->builder, bad_not_suspended_bb); + gen_assertion_scope(g, PanicMsgIdResumeNotSuspendedFn, fn_table_entry->child_scope); + } + LLVMPositionBuilderAtEnd(g->builder, entry_block->llvm_block); + LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr); if (trace_field_index_stack != UINT32_MAX) { if (codegen_fn_has_err_ret_tracing_arg(g, fn_type_id->return_type)) { LLVMValueRef trace_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 17f0f3230c..d278407ee1 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,54 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety("resuming a non-suspended function which never been suspended", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\fn foo() void { + \\ var f = async bar(@frame()); + \\ @import("std").os.exit(0); + \\} + \\ + \\fn bar(frame: anyframe) void { + \\ suspend { + \\ resume frame; + \\ } + \\ @import("std").os.exit(0); + \\} + \\ + \\pub fn main() void { + \\ _ = async foo(); + \\} + ); + + cases.addRuntimeSafety("resuming a non-suspended function which has been suspended and resumed", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\fn foo() void { + \\ suspend { + \\ global_frame = @frame(); + \\ } + \\ var f = async bar(@frame()); + \\ @import("std").os.exit(0); + \\} + \\ + \\fn bar(frame: anyframe) void { + \\ suspend { + \\ resume frame; + \\ } + \\ @import("std").os.exit(0); + \\} + \\ + \\var global_frame: anyframe = undefined; + \\pub fn main() void { + \\ _ = async foo(); + \\ resume global_frame; + \\ @import("std").os.exit(0); + \\} + ); + cases.addRuntimeSafety("noasync function call, callee suspends", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126);