add runtime safety for resuming an awaiting function

This commit is contained in:
Andrew Kelley 2019-08-03 02:11:52 -04:00
parent 24d78177ee
commit e444e737b7
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
5 changed files with 79 additions and 7 deletions

View File

@ -6,7 +6,6 @@
* @asyncCall with an async function pointer
* cancel
* defer and errdefer
* safety for resuming when it is awaiting
* safety for double await
* implicit cast of normal function to async function should be allowed when it is inferred to be async
* go over the commented out tests
@ -19,3 +18,6 @@
* make sure there are safety tests for all the new safety features (search the new PanicFnId enum values)
* error return tracing
* compile error for casting a function to a non-async function pointer, but then later it gets inferred to be an async function
* compile error for copying a frame
* compile error for resuming a const frame pointer
* runtime safety enabling/disabling scope has to be coordinated across resume/await/calls/return

View File

@ -1552,6 +1552,7 @@ enum PanicMsgId {
PanicMsgIdBadResume,
PanicMsgIdBadAwait,
PanicMsgIdBadReturn,
PanicMsgIdResumedAnAwaitingFn,
PanicMsgIdCount,
};

View File

@ -877,6 +877,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
return buf_create_from_str("async function awaited twice");
case PanicMsgIdBadReturn:
return buf_create_from_str("async function returned twice");
case PanicMsgIdResumedAnAwaitingFn:
return buf_create_from_str("awaiting function resumed");
}
zig_unreachable();
}
@ -2018,7 +2020,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns
}
result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, "");
} else {
result_ptr_as_usize = LLVMGetUndef(usize_type_ref);
// For debug safety, this value has to be anything other than all 1's, which signals
// that it is being resumed. 0 is a bad choice since null pointers are special.
result_ptr_as_usize = ir_want_runtime_safety(g, &return_instruction->base) ?
LLVMConstInt(usize_type_ref, 1, false) : LLVMGetUndef(usize_type_ref);
}
LLVMValueRef zero = LLVMConstNull(usize_type_ref);
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
@ -3582,8 +3587,9 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr);
}
}
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
if (instruction->is_async) {
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
return nullptr;
} else if (callee_is_async) {
@ -3591,8 +3597,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn);
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, "");
LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr);
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
ZigLLVMSetTailCall(call_inst);
LLVMBuildRetVoid(g->builder);
@ -3601,6 +3606,21 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume");
LLVMPositionBuilderAtEnd(g->builder, call_bb);
if (ir_want_runtime_safety(g, &instruction->base)) {
LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
}
render_async_var_decls(g, instruction->base.scope);
if (type_has_bits(src_return_type)) {
@ -5139,6 +5159,21 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "AwaitResume");
LLVMPositionBuilderAtEnd(g->builder, call_bb);
if (ir_want_runtime_safety(g, &instruction->base)) {
LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
}
render_async_var_decls(g, instruction->base.scope);
if (type_has_bits(result_type)) {
@ -5178,7 +5213,9 @@ static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, anyframe_fn_type(g), "");
LLVMValueRef args[] = {frame, LLVMGetUndef(usize_type_ref)};
LLVMValueRef arg_val = ir_want_runtime_safety(g, &instruction->base) ?
LLVMConstAllOnes(usize_type_ref) : LLVMGetUndef(usize_type_ref);
LLVMValueRef args[] = {frame, arg_val};
ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
return nullptr;
}

View File

@ -1,6 +1,38 @@
const tests = @import("tests.zig");
pub fn addCases(cases: *tests.CompareOutputContext) void {
cases.addRuntimeSafety("resuming a function which is awaiting a frame",
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
\\}
\\pub fn main() void {
\\ var frame = async first();
\\ resume frame;
\\}
\\fn first() void {
\\ var frame = async other();
\\ await frame;
\\}
\\fn other() void {
\\ suspend;
\\}
);
cases.addRuntimeSafety("resuming a function which is awaiting a call",
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
\\}
\\pub fn main() void {
\\ var frame = async first();
\\ resume frame;
\\}
\\fn first() void {
\\ other();
\\}
\\fn other() void {
\\ suspend;
\\}
);
cases.addRuntimeSafety("invalid resume of async function",
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);

View File

@ -89,7 +89,7 @@ test "calling an inferred async function" {
var other_frame: *@Frame(other) = undefined;
fn doTheTest() void {
const p = async first();
_ = async first();
expect(x == 1);
resume other_frame.*;
expect(x == 2);