From 0099583bd3eccbecbb827edbde46a40cf821fecf Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 8 May 2019 17:39:00 -0400 Subject: [PATCH] C pointers support .? operator see #1967 --- src/all_types.hpp | 7 ++++ src/codegen.cpp | 38 +++++++++++++++++++++- src/ir.cpp | 53 ++++++++++++++++++++++++++++--- src/ir_print.cpp | 9 ++++++ test/runtime_safety.zig | 20 ++++++++++++ test/stage1/behavior/pointers.zig | 18 ++++++----- 6 files changed, 132 insertions(+), 13 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 5a5c1cfda4..5cba3f2230 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2293,6 +2293,7 @@ enum IrInstructionId { IrInstructionIdVectorToArray, IrInstructionIdArrayToVector, IrInstructionIdAssertZero, + IrInstructionIdAssertNonNull, }; struct IrInstruction { @@ -3482,6 +3483,12 @@ struct IrInstructionAssertZero { IrInstruction *target; }; +struct IrInstructionAssertNonNull { + IrInstruction base; + + IrInstruction *target; +}; + static const size_t slice_ptr_index = 0; static const size_t slice_len_index = 1; diff --git a/src/codegen.cpp b/src/codegen.cpp index 712ee908cb..db7d96f4df 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1008,10 +1008,19 @@ static void gen_panic(CodeGen *g, LLVMValueRef msg_arg, LLVMValueRef stack_trace LLVMBuildUnreachable(g->builder); } +// TODO update most callsites to call gen_assertion instead of this static void gen_safety_crash(CodeGen *g, PanicMsgId msg_id) { gen_panic(g, get_panic_msg_ptr_val(g, msg_id), nullptr); } +static void gen_assertion(CodeGen *g, PanicMsgId msg_id, IrInstruction *source_instruction) { + if (ir_want_runtime_safety(g, source_instruction)) { + gen_safety_crash(g, msg_id); + } else { + LLVMBuildUnreachable(g->builder); + } +} + static LLVMValueRef get_stacksave_fn_val(CodeGen *g) { if (g->stacksave_fn_val) return g->stacksave_fn_val; @@ -4056,8 +4065,8 @@ static LLVMValueRef ir_render_optional_unwrap_ptr(CodeGen *g, IrExecutable *exec if (ir_want_runtime_safety(g, &instruction->base) && instruction->safety_check_on) { LLVMValueRef maybe_handle = get_handle_value(g, maybe_ptr, maybe_type, ptr_type); LLVMValueRef non_null_bit = gen_non_null_bit(g, maybe_type, maybe_handle); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapOptionalOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapOptionalFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapOptionalOk"); LLVMBuildCondBr(g->builder, non_null_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -5487,6 +5496,31 @@ static LLVMValueRef ir_render_assert_zero(CodeGen *g, IrExecutable *executable, return nullptr; } +static LLVMValueRef ir_render_assert_non_null(CodeGen *g, IrExecutable *executable, + IrInstructionAssertNonNull *instruction) +{ + LLVMValueRef target = ir_llvm_value(g, instruction->target); + ZigType *target_type = instruction->target->value.type; + + if (target_type->id == ZigTypeIdPointer) { + assert(target_type->data.pointer.ptr_len == PtrLenC); + LLVMValueRef non_null_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target, + LLVMConstNull(get_llvm_type(g, target_type)), ""); + + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "AssertNonNullFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "AssertNonNullOk"); + LLVMBuildCondBr(g->builder, non_null_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_assertion(g, PanicMsgIdUnwrapOptionalFail, &instruction->base); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } else { + zig_unreachable(); + } + return nullptr; +} + static void set_debug_location(CodeGen *g, IrInstruction *instruction) { AstNode *source_node = instruction->source_node; Scope *scope = instruction->scope; @@ -5741,6 +5775,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_vector_to_array(g, executable, (IrInstructionVectorToArray *)instruction); case IrInstructionIdAssertZero: return ir_render_assert_zero(g, executable, (IrInstructionAssertZero *)instruction); + case IrInstructionIdAssertNonNull: + return ir_render_assert_non_null(g, executable, (IrInstructionAssertNonNull *)instruction); case IrInstructionIdResizeSlice: return ir_render_resize_slice(g, executable, (IrInstructionResizeSlice *)instruction); } diff --git a/src/ir.cpp b/src/ir.cpp index 1b0fbd1f7f..90e8f1ed8f 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1003,6 +1003,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionAssertZero *) { return IrInstructionIdAssertZero; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionAssertNonNull *) { + return IrInstructionIdAssertNonNull; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -3037,6 +3041,19 @@ static IrInstruction *ir_build_assert_zero(IrAnalyze *ira, IrInstruction *source return &instruction->base; } +static IrInstruction *ir_build_assert_non_null(IrAnalyze *ira, IrInstruction *source_instruction, + IrInstruction *target) +{ + IrInstructionAssertNonNull *instruction = ir_build_instruction(&ira->new_irb, + source_instruction->scope, source_instruction->source_node); + instruction->base.value.type = ira->codegen->builtin_types.entry_void; + instruction->target = target; + + ir_ref_instruction(target, ira->new_irb.current_basic_block); + + return &instruction->base; +} + static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) { results[ReturnKindUnconditional] = 0; results[ReturnKindError] = 0; @@ -16869,6 +16886,32 @@ static IrInstruction *ir_analyze_unwrap_optional_payload(IrAnalyze *ira, IrInstr if (type_is_invalid(type_entry)) return ira->codegen->invalid_instruction; + if (type_entry->id == ZigTypeIdPointer && type_entry->data.pointer.ptr_len == PtrLenC) { + if (instr_is_comptime(base_ptr)) { + ConstExprValue *val = ir_resolve_const(ira, base_ptr, UndefBad); + if (!val) + return ira->codegen->invalid_instruction; + if (val->data.x_ptr.mut != ConstPtrMutRuntimeVar) { + ConstExprValue *c_ptr_val = const_ptr_pointee(ira, ira->codegen, val, source_instr->source_node); + if (c_ptr_val == nullptr) + return ira->codegen->invalid_instruction; + bool is_null = c_ptr_val->data.x_ptr.special == ConstPtrSpecialNull || + (c_ptr_val->data.x_ptr.special == ConstPtrSpecialHardCodedAddr && + c_ptr_val->data.x_ptr.data.hard_coded_addr.addr == 0); + if (is_null) { + ir_add_error(ira, source_instr, buf_sprintf("unable to unwrap null")); + return ira->codegen->invalid_instruction; + } + return base_ptr; + } + } + if (!safety_check_on) + return base_ptr; + IrInstruction *c_ptr_val = ir_get_deref(ira, source_instr, base_ptr); + ir_build_assert_non_null(ira, source_instr, c_ptr_val); + return base_ptr; + } + if (type_entry->id != ZigTypeIdOptional) { ir_add_error_node(ira, base_ptr->source_node, buf_sprintf("expected optional type, found '%s'", buf_ptr(&type_entry->name))); @@ -16883,11 +16926,11 @@ static IrInstruction *ir_analyze_unwrap_optional_payload(IrAnalyze *ira, IrInstr ConstExprValue *val = ir_resolve_const(ira, base_ptr, UndefBad); if (!val) return ira->codegen->invalid_instruction; - ConstExprValue *maybe_val = const_ptr_pointee(ira, ira->codegen, val, source_instr->source_node); - if (maybe_val == nullptr) - return ira->codegen->invalid_instruction; - if (val->data.x_ptr.mut != ConstPtrMutRuntimeVar) { + ConstExprValue *maybe_val = const_ptr_pointee(ira, ira->codegen, val, source_instr->source_node); + if (maybe_val == nullptr) + return ira->codegen->invalid_instruction; + if (optional_value_is_null(maybe_val)) { ir_add_error(ira, source_instr, buf_sprintf("unable to unwrap null")); return ira->codegen->invalid_instruction; @@ -22942,6 +22985,7 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio case IrInstructionIdArrayToVector: case IrInstructionIdVectorToArray: case IrInstructionIdAssertZero: + case IrInstructionIdAssertNonNull: case IrInstructionIdResizeSlice: case IrInstructionIdLoadPtrGen: case IrInstructionIdBitCastGen: @@ -23346,6 +23390,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCmpxchgGen: case IrInstructionIdCmpxchgSrc: case IrInstructionIdAssertZero: + case IrInstructionIdAssertNonNull: case IrInstructionIdResizeSlice: case IrInstructionIdGlobalAsm: return true; diff --git a/src/ir_print.cpp b/src/ir_print.cpp index dba0e4ee00..08f5cd01a4 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1003,6 +1003,12 @@ static void ir_print_assert_zero(IrPrint *irp, IrInstructionAssertZero *instruct fprintf(irp->f, ")"); } +static void ir_print_assert_non_null(IrPrint *irp, IrInstructionAssertNonNull *instruction) { + fprintf(irp->f, "AssertNonNull("); + ir_print_other_instruction(irp, instruction->target); + fprintf(irp->f, ")"); +} + static void ir_print_resize_slice(IrPrint *irp, IrInstructionResizeSlice *instruction) { fprintf(irp->f, "@resizeSlice("); ir_print_other_instruction(irp, instruction->operand); @@ -1880,6 +1886,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdAssertZero: ir_print_assert_zero(irp, (IrInstructionAssertZero *)instruction); break; + case IrInstructionIdAssertNonNull: + ir_print_assert_non_null(irp, (IrInstructionAssertNonNull *)instruction); + break; case IrInstructionIdResizeSlice: ir_print_resize_slice(irp, (IrInstructionResizeSlice *)instruction); break; diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 78b45ac05f..b10accd213 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,26 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety(".? operator on null pointer", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var ptr: ?*i32 = null; + \\ var b = ptr.?; + \\} + ); + + cases.addRuntimeSafety(".? operator on C pointer", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var ptr: [*c]i32 = null; + \\ var b = ptr.?; + \\} + ); + cases.addRuntimeSafety("@ptrToInt address zero to non-optional pointer", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126); diff --git a/test/stage1/behavior/pointers.zig b/test/stage1/behavior/pointers.zig index 8b1f7d0cb8..1a7616e757 100644 --- a/test/stage1/behavior/pointers.zig +++ b/test/stage1/behavior/pointers.zig @@ -159,10 +159,10 @@ test "assign null directly to C pointer and test null equality" { expect(!(null != x)); const y: [*c]i32 = null; - expect(y == null); - expect(null == y); - expect(!(y != null)); - expect(!(null != y)); + comptime expect(y == null); + comptime expect(null == y); + comptime expect(!(y != null)); + comptime expect(!(null != y)); var n: i32 = 1234; var x1: [*c]i32 = &n; @@ -170,11 +170,13 @@ test "assign null directly to C pointer and test null equality" { expect(!(null == x1)); expect(x1 != null); expect(null != x1); + expect(x1.?.* == 1234); const nc: i32 = 1234; const y1: [*c]const i32 = &nc; - expect(!(y1 == null)); - expect(!(null == y1)); - expect(y1 != null); - expect(null != y1); + comptime expect(!(y1 == null)); + comptime expect(!(null == y1)); + comptime expect(y1 != null); + comptime expect(null != y1); + comptime expect(y1.?.* == 1234); }