diff --git a/lib/std/debug.zig b/lib/std/debug.zig index 9b24e1acd3..76685666ad 100644 --- a/lib/std/debug.zig +++ b/lib/std/debug.zig @@ -219,7 +219,7 @@ pub fn panic(comptime format: []const u8, args: var) noreturn { } /// TODO multithreaded awareness -var panicking: u8 = 0; // TODO make this a bool +var panicking: u8 = 0; pub fn panicExtra(trace: ?*const builtin.StackTrace, first_trace_addr: ?usize, comptime format: []const u8, args: var) noreturn { @setCold(true); @@ -230,21 +230,25 @@ pub fn panicExtra(trace: ?*const builtin.StackTrace, first_trace_addr: ?usize, c resetSegfaultHandler(); } - if (@atomicRmw(u8, &panicking, builtin.AtomicRmwOp.Xchg, 1, builtin.AtomicOrder.SeqCst) == 1) { - // Panicked during a panic. - - // TODO detect if a different thread caused the panic, because in that case - // we would want to return here instead of calling abort, so that the thread - // which first called panic can finish printing a stack trace. - os.abort(); + switch (@atomicRmw(u8, &panicking, .Add, 1, .SeqCst)) { + 0 => { + const stderr = getStderrStream(); + stderr.print(format ++ "\n", args) catch os.abort(); + if (trace) |t| { + dumpStackTrace(t.*); + } + dumpCurrentStackTrace(first_trace_addr); + }, + 1 => { + // TODO detect if a different thread caused the panic, because in that case + // we would want to return here instead of calling abort, so that the thread + // which first called panic can finish printing a stack trace. + warn("Panicked during a panic. Aborting.\n", .{}); + }, + else => { + // Panicked while printing "Panicked during a panic." + }, } - const stderr = getStderrStream(); - stderr.print(format ++ "\n", args) catch os.abort(); - if (trace) |t| { - dumpStackTrace(t.*); - } - dumpCurrentStackTrace(first_trace_addr); - os.abort(); } diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 824ab2f556..133e3b8a2d 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -364,11 +364,11 @@ pub fn len(comptime T: type, ptr: [*:0]const T) usize { } pub fn toSliceConst(comptime T: type, ptr: [*:0]const T) [:0]const T { - return ptr[0..len(T, ptr)]; + return ptr[0..len(T, ptr) :0]; } pub fn toSlice(comptime T: type, ptr: [*:0]T) [:0]T { - return ptr[0..len(T, ptr)]; + return ptr[0..len(T, ptr) :0]; } /// Returns true if all elements in a slice are equal to the scalar value provided diff --git a/src/all_types.hpp b/src/all_types.hpp index 144cdd4fc5..ea46ab81a6 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1779,6 +1779,7 @@ enum PanicMsgId { PanicMsgIdResumedFnPendingAwait, PanicMsgIdBadNoAsyncCall, PanicMsgIdResumeNotSuspendedFn, + PanicMsgIdBadSentinel, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index edc2c7f435..01a15b5f2c 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -941,6 +941,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("async function called with noasync suspended"); case PanicMsgIdResumeNotSuspendedFn: return buf_create_from_str("resumed a non-suspended function"); + case PanicMsgIdBadSentinel: + return buf_create_from_str("sentinel mismatch"); } zig_unreachable(); } @@ -1419,6 +1421,22 @@ static void add_bounds_check(CodeGen *g, LLVMValueRef target_val, LLVMPositionBuilderAtEnd(g->builder, ok_block); } +static void add_sentinel_check(CodeGen *g, LLVMValueRef sentinel_elem_ptr, ZigValue *sentinel) { + LLVMValueRef expected_sentinel = gen_const_val(g, sentinel, ""); + + LLVMValueRef actual_sentinel = gen_load_untyped(g, sentinel_elem_ptr, 0, false, ""); + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, actual_sentinel, expected_sentinel, ""); + + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "SentinelFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "SentinelOk"); + LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_safety_crash(g, PanicMsgIdBadSentinel); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); +} + static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) { LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, int_type)); LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, ""); @@ -5244,6 +5262,9 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst bool want_runtime_safety = instruction->safety_check_on && ir_want_runtime_safety(g, &instruction->base); + ZigType *res_slice_ptr_type = instruction->base.value->type->data.structure.fields[slice_ptr_index]->type_entry; + ZigValue *sentinel = res_slice_ptr_type->data.pointer.sentinel; + if (array_type->id == ZigTypeIdArray || (array_type->id == ZigTypeIdPointer && array_type->data.pointer.ptr_len == PtrLenSingle)) { @@ -5265,6 +5286,15 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, array_type->data.array.len, false); add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end); + + if (sentinel != nullptr) { + LLVMValueRef indices[] = { + LLVMConstNull(g->builtin_types.entry_usize->llvm_type), + end_val, + }; + LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, ""); + add_sentinel_check(g, sentinel_elem_ptr, sentinel); + } } } if (!type_has_bits(array_type)) { @@ -5297,6 +5327,10 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst if (want_runtime_safety) { add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + if (sentinel != nullptr) { + LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &end_val, 1, ""); + add_sentinel_check(g, sentinel_elem_ptr, sentinel); + } } if (type_has_bits(array_type)) { @@ -5337,18 +5371,24 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst end_val = prev_end; } + LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)ptr_index, ""); + LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, ""); + if (want_runtime_safety) { assert(prev_end); add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); if (instruction->end) { add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end); + + if (sentinel != nullptr) { + LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &end_val, 1, ""); + add_sentinel_check(g, sentinel_elem_ptr, sentinel); + } } } - LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)ptr_index, ""); - LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, ""); LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)ptr_index, ""); - LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, (unsigned)len_index, ""); + LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, 1, ""); gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false); LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)len_index, ""); diff --git a/src/ir.cpp b/src/ir.cpp index e172b69c96..280ba76f2b 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -25122,14 +25122,16 @@ static IrInstruction *ir_analyze_instruction_slice(IrAnalyze *ira, IrInstruction if (array_type->data.pointer.ptr_len == PtrLenC) { array_type = adjust_ptr_len(ira->codegen, array_type, PtrLenUnknown); } - non_sentinel_slice_ptr_type = array_type; + ZigType *maybe_sentineled_slice_ptr_type = array_type; + non_sentinel_slice_ptr_type = adjust_ptr_sentinel(ira->codegen, maybe_sentineled_slice_ptr_type, nullptr); if (!end) { ir_add_error(ira, &instruction->base, buf_sprintf("slice of pointer must include end value")); return ira->codegen->invalid_instruction; } } } else if (is_slice(array_type)) { - non_sentinel_slice_ptr_type = array_type->data.structure.fields[slice_ptr_index]->type_entry; + ZigType *maybe_sentineled_slice_ptr_type = array_type->data.structure.fields[slice_ptr_index]->type_entry; + non_sentinel_slice_ptr_type = adjust_ptr_sentinel(ira->codegen, maybe_sentineled_slice_ptr_type, nullptr); elem_type = non_sentinel_slice_ptr_type->data.pointer.child_type; } else { ir_add_error(ira, &instruction->base, diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index eec5f7a86b..6a1cc808fd 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,12 +1,57 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { - cases.addRuntimeSafety("intToPtr with misaligned address", + cases.addRuntimeSafety("pointer slice sentinel mismatch", + \\const std = @import("std"); \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { - \\ if (@import("std").mem.eql(u8, message, "incorrect alignment")) { - \\ @import("std").os.exit(126); // good + \\ if (std.mem.eql(u8, message, "sentinel mismatch")) { + \\ std.process.exit(126); // good \\ } - \\ @import("std").os.exit(0); // test failed + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var buf: [4]u8 = undefined; + \\ const ptr = buf[0..].ptr; + \\ const slice = ptr[0..3 :0]; + \\} + ); + + cases.addRuntimeSafety("slice slice sentinel mismatch", + \\const std = @import("std"); + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "sentinel mismatch")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var buf: [4]u8 = undefined; + \\ const slice = buf[0..]; + \\ const slice2 = slice[0..3 :0]; + \\} + ); + + cases.addRuntimeSafety("array slice sentinel mismatch", + \\const std = @import("std"); + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "sentinel mismatch")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var buf: [4]u8 = undefined; + \\ const slice = buf[0..3 :0]; + \\} + ); + + cases.addRuntimeSafety("intToPtr with misaligned address", + \\const std = @import("std"); + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "incorrect alignment")) { + \\ std.os.exit(126); // good + \\ } + \\ std.os.exit(0); // test failed \\} \\pub fn main() void { \\ var x: usize = 5;