mirror of
https://github.com/ziglang/zig.git
synced 2026-02-13 04:48:20 +00:00
add runtime safety for resuming an awaiting function
This commit is contained in:
parent
24d78177ee
commit
e444e737b7
@ -6,7 +6,6 @@
|
||||
* @asyncCall with an async function pointer
|
||||
* cancel
|
||||
* defer and errdefer
|
||||
* safety for resuming when it is awaiting
|
||||
* safety for double await
|
||||
* implicit cast of normal function to async function should be allowed when it is inferred to be async
|
||||
* go over the commented out tests
|
||||
@ -19,3 +18,6 @@
|
||||
* make sure there are safety tests for all the new safety features (search the new PanicFnId enum values)
|
||||
* error return tracing
|
||||
* compile error for casting a function to a non-async function pointer, but then later it gets inferred to be an async function
|
||||
* compile error for copying a frame
|
||||
* compile error for resuming a const frame pointer
|
||||
* runtime safety enabling/disabling scope has to be coordinated across resume/await/calls/return
|
||||
|
||||
@ -1552,6 +1552,7 @@ enum PanicMsgId {
|
||||
PanicMsgIdBadResume,
|
||||
PanicMsgIdBadAwait,
|
||||
PanicMsgIdBadReturn,
|
||||
PanicMsgIdResumedAnAwaitingFn,
|
||||
|
||||
PanicMsgIdCount,
|
||||
};
|
||||
|
||||
@ -877,6 +877,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
|
||||
return buf_create_from_str("async function awaited twice");
|
||||
case PanicMsgIdBadReturn:
|
||||
return buf_create_from_str("async function returned twice");
|
||||
case PanicMsgIdResumedAnAwaitingFn:
|
||||
return buf_create_from_str("awaiting function resumed");
|
||||
}
|
||||
zig_unreachable();
|
||||
}
|
||||
@ -2018,7 +2020,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns
|
||||
}
|
||||
result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, "");
|
||||
} else {
|
||||
result_ptr_as_usize = LLVMGetUndef(usize_type_ref);
|
||||
// For debug safety, this value has to be anything other than all 1's, which signals
|
||||
// that it is being resumed. 0 is a bad choice since null pointers are special.
|
||||
result_ptr_as_usize = ir_want_runtime_safety(g, &return_instruction->base) ?
|
||||
LLVMConstInt(usize_type_ref, 1, false) : LLVMGetUndef(usize_type_ref);
|
||||
}
|
||||
LLVMValueRef zero = LLVMConstNull(usize_type_ref);
|
||||
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
|
||||
@ -3582,8 +3587,9 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
||||
LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr);
|
||||
}
|
||||
}
|
||||
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
|
||||
if (instruction->is_async) {
|
||||
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
|
||||
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
|
||||
ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
|
||||
return nullptr;
|
||||
} else if (callee_is_async) {
|
||||
@ -3591,8 +3597,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
||||
LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn);
|
||||
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, "");
|
||||
LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr);
|
||||
|
||||
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
|
||||
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
|
||||
LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
|
||||
ZigLLVMSetTailCall(call_inst);
|
||||
LLVMBuildRetVoid(g->builder);
|
||||
@ -3601,6 +3606,21 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
||||
g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
|
||||
LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume");
|
||||
LLVMPositionBuilderAtEnd(g->builder, call_bb);
|
||||
|
||||
if (ir_want_runtime_safety(g, &instruction->base)) {
|
||||
LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
|
||||
LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
|
||||
LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
|
||||
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
|
||||
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
|
||||
gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
|
||||
}
|
||||
|
||||
render_async_var_decls(g, instruction->base.scope);
|
||||
|
||||
if (type_has_bits(src_return_type)) {
|
||||
@ -5139,6 +5159,21 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
|
||||
g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
|
||||
LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "AwaitResume");
|
||||
LLVMPositionBuilderAtEnd(g->builder, call_bb);
|
||||
|
||||
if (ir_want_runtime_safety(g, &instruction->base)) {
|
||||
LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
|
||||
LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
|
||||
LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
|
||||
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
|
||||
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
|
||||
gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
|
||||
}
|
||||
|
||||
render_async_var_decls(g, instruction->base.scope);
|
||||
|
||||
if (type_has_bits(result_type)) {
|
||||
@ -5178,7 +5213,9 @@ static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
|
||||
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
|
||||
LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
|
||||
LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, anyframe_fn_type(g), "");
|
||||
LLVMValueRef args[] = {frame, LLVMGetUndef(usize_type_ref)};
|
||||
LLVMValueRef arg_val = ir_want_runtime_safety(g, &instruction->base) ?
|
||||
LLVMConstAllOnes(usize_type_ref) : LLVMGetUndef(usize_type_ref);
|
||||
LLVMValueRef args[] = {frame, arg_val};
|
||||
ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -1,6 +1,38 @@
|
||||
const tests = @import("tests.zig");
|
||||
|
||||
pub fn addCases(cases: *tests.CompareOutputContext) void {
|
||||
cases.addRuntimeSafety("resuming a function which is awaiting a frame",
|
||||
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||
\\ @import("std").os.exit(126);
|
||||
\\}
|
||||
\\pub fn main() void {
|
||||
\\ var frame = async first();
|
||||
\\ resume frame;
|
||||
\\}
|
||||
\\fn first() void {
|
||||
\\ var frame = async other();
|
||||
\\ await frame;
|
||||
\\}
|
||||
\\fn other() void {
|
||||
\\ suspend;
|
||||
\\}
|
||||
);
|
||||
cases.addRuntimeSafety("resuming a function which is awaiting a call",
|
||||
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||
\\ @import("std").os.exit(126);
|
||||
\\}
|
||||
\\pub fn main() void {
|
||||
\\ var frame = async first();
|
||||
\\ resume frame;
|
||||
\\}
|
||||
\\fn first() void {
|
||||
\\ other();
|
||||
\\}
|
||||
\\fn other() void {
|
||||
\\ suspend;
|
||||
\\}
|
||||
);
|
||||
|
||||
cases.addRuntimeSafety("invalid resume of async function",
|
||||
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||
\\ @import("std").os.exit(126);
|
||||
|
||||
@ -89,7 +89,7 @@ test "calling an inferred async function" {
|
||||
var other_frame: *@Frame(other) = undefined;
|
||||
|
||||
fn doTheTest() void {
|
||||
const p = async first();
|
||||
_ = async first();
|
||||
expect(x == 1);
|
||||
resume other_frame.*;
|
||||
expect(x == 2);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user