From 300fceac6eca99fd858678b03a47357e72856e10 Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Tue, 10 Mar 2020 20:54:05 +0100 Subject: [PATCH 1/3] ir: Implement more safety checks for shl/shr The checks are now valid on types whose size is not a power of two. Closes #2096 --- src/all_types.hpp | 1 + src/codegen.cpp | 32 +++++++++++++++++++++++++++ src/ir.cpp | 48 ++++++++++++++++++++--------------------- test/compile_errors.zig | 35 ++++++++++++++++++++++++++++-- test/runtime_safety.zig | 31 ++++++++++++++++++++++++++ 5 files changed, 120 insertions(+), 27 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 14b99228ca..53aae9e236 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1834,6 +1834,7 @@ enum PanicMsgId { PanicMsgIdBadNoAsyncCall, PanicMsgIdResumeNotSuspendedFn, PanicMsgIdBadSentinel, + PanicMsgIdShxTooBigRhs, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 22cb975205..d659e27d86 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -974,6 +974,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("resumed a non-suspended function"); case PanicMsgIdBadSentinel: return buf_create_from_str("sentinel mismatch"); + case PanicMsgIdShxTooBigRhs: + return buf_create_from_str("shift amount is greater than the type size"); } zig_unreachable(); } @@ -2841,6 +2843,26 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast } +static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type, LLVMValueRef value) { + // We only check if the rhs value of the shift expression is greater or + // equal to the number of bits of the lhs if it's not a power of two, + // otherwise the check is useless as the allowed values are limited by the + // operand type itself + if (!is_power_of_2(lhs_type->data.integral.bit_count)) { + LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type), + lhs_type->data.integral.bit_count, false); + LLVMValueRef 
less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk"); + LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_safety_crash(g, PanicMsgIdShxTooBigRhs); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } +} + static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, IrInstGenBinOp *bin_op_instruction) { @@ -2949,6 +2971,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, { assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + + if (want_runtime_safety) { + gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); + } + bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy); if (is_sloppy) { return LLVMBuildShl(g->builder, op1_value, op2_casted, ""); @@ -2965,6 +2992,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, { assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + + if (want_runtime_safety) { + gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); + } + bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy); if (is_sloppy) { if (scalar_type->data.integral.is_signed) { diff --git a/src/ir.cpp b/src/ir.cpp index e5b28f84c2..b9b42141bc 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -16648,36 +16648,34 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in return ira->codegen->invalid_inst_gen; } } else { + assert(op1->value->type->data.integral.bit_count > 0); ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen, - op1->value->type->data.integral.bit_count - 1); - if 
(bin_op_instruction->op_id == IrBinOpBitShiftLeftLossy && - op2->value->type->id == ZigTypeIdComptimeInt) { - - ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad); - if (op2_val == nullptr) - return ira->codegen->invalid_inst_gen; - if (!bigint_fits_in_bits(&op2_val->data.x_bigint, - shift_amt_type->data.integral.bit_count, - op2_val->data.x_bigint.is_negative)) { - Buf *val_buf = buf_alloc(); - bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10); - ErrorMsg* msg = ir_add_error(ira, - &bin_op_instruction->base.base, - buf_sprintf("RHS of shift is too large for LHS type")); - add_error_note( - ira->codegen, - msg, - op2->base.source_node, - buf_sprintf("value %s cannot fit into type %s", - buf_ptr(val_buf), - buf_ptr(&shift_amt_type->name))); - return ira->codegen->invalid_inst_gen; - } - } + op1->value->type->data.integral.bit_count - 1); casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type); if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; + + if (instr_is_comptime(casted_op2)) { + ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); + if (op2_val == nullptr) + return ira->codegen->invalid_inst_gen; + + BigInt bit_count_value = {0}; + bigint_init_unsigned(&bit_count_value, op1->value->type->data.integral.bit_count); + + if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) { + ErrorMsg* msg = ir_add_error(ira, + &bin_op_instruction->base.base, + buf_sprintf("RHS of shift is too large for LHS type")); + add_error_note(ira->codegen, msg, op1->base.source_node, + buf_sprintf("type %s has only %u bits", + buf_ptr(&op1->value->type->name), + op1->value->type->data.integral.bit_count)); + + return ira->codegen->invalid_inst_gen; + } + } } if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) { diff --git a/test/compile_errors.zig b/test/compile_errors.zig index a2d4e8ac23..5078453332 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -2,6 +2,38 @@ const tests = 
@import("tests.zig"); const std = @import("std"); pub fn addCases(cases: *tests.CompileErrorContext) void { + cases.addTest("shift on type with non-power-of-two size", + \\export fn entry() void { + \\ const S = struct { + \\ fn a() void { + \\ var x: u24 = 42; + \\ _ = x >> 24; + \\ } + \\ fn b() void { + \\ var x: u24 = 42; + \\ _ = x << 24; + \\ } + \\ fn c() void { + \\ var x: u24 = 42; + \\ _ = @shlExact(x, 24); + \\ } + \\ fn d() void { + \\ var x: u24 = 42; + \\ _ = @shrExact(x, 24); + \\ } + \\ }; + \\ S.a(); + \\ S.b(); + \\ S.c(); + \\ S.d(); + \\} + , &[_][]const u8{ + "tmp.zig:5:19: error: RHS of shift is too large for LHS type", + "tmp.zig:9:19: error: RHS of shift is too large for LHS type", + "tmp.zig:13:17: error: RHS of shift is too large for LHS type", + "tmp.zig:17:17: error: RHS of shift is too large for LHS type", + }); + cases.addTest("combination of noasync and async", \\export fn entry() void { \\ noasync { @@ -4029,8 +4061,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\} \\export fn entry() u16 { return f(); } , &[_][]const u8{ - "tmp.zig:3:14: error: RHS of shift is too large for LHS type", - "tmp.zig:3:17: note: value 8 cannot fit into type u3", + "tmp.zig:3:17: error: integer value 8 cannot be coerced to type 'u3'", }); cases.add("missing function call param", diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index e183c6979f..9855aae16b 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,37 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety("shift left by huge amount", + \\const std = @import("std"); + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ std.debug.warn("{}\n", .{message}); + \\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + 
\\pub fn main() void { + \\ var x: u24 = 42; + \\ var y: u5 = 24; + \\ var z = x << y; + \\} + ); + + cases.addRuntimeSafety("shift right by huge amount", + \\const std = @import("std"); + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var x: u24 = 42; + \\ var y: u5 = 24; + \\ var z = x >> y; + \\} + ); + cases.addRuntimeSafety("slice sentinel mismatch - optional pointers", \\const std = @import("std"); \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { From 4ab13a359dfb14c93869e7f88320ec2aa438da9c Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Tue, 10 Mar 2020 23:04:49 +0100 Subject: [PATCH 2/3] ir: Fix shift code for u0 operands --- src/ir.cpp | 46 +++++++++++++++++++++++------------ test/stage1/behavior/math.zig | 19 +++++++++++++++ 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/src/ir.cpp b/src/ir.cpp index b9b42141bc..bb2dc75c64 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -16635,34 +16635,47 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in IrInstGen *casted_op2; IrBinOp op_id = bin_op_instruction->op_id; if (op1->value->type->id == ZigTypeIdComptimeInt) { + // comptime_int has no finite bit width casted_op2 = op2; if (op_id == IrBinOpBitShiftLeftLossy) { op_id = IrBinOpBitShiftLeftExact; } - if (casted_op2->value->data.x_bigint.is_negative) { + if (!instr_is_comptime(op2)) { + ir_add_error(ira, &bin_op_instruction->base.base, + buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known")); + return ira->codegen->invalid_inst_gen; + } + + ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); + if (op2_val == nullptr) + return ira->codegen->invalid_inst_gen; + + if 
(op2_val->data.x_bigint.is_negative) { Buf *val_buf = buf_alloc(); - bigint_append_buf(val_buf, &casted_op2->value->data.x_bigint, 10); - ir_add_error(ira, &casted_op2->base, buf_sprintf("shift by negative value %s", buf_ptr(val_buf))); + bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10); + ir_add_error(ira, &casted_op2->base, + buf_sprintf("shift by negative value %s", buf_ptr(val_buf))); return ira->codegen->invalid_inst_gen; } } else { - assert(op1->value->type->data.integral.bit_count > 0); + const unsigned bit_count = op1->value->type->data.integral.bit_count; ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen, - op1->value->type->data.integral.bit_count - 1); + bit_count > 0 ? bit_count - 1 : 0); casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type); if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; - if (instr_is_comptime(casted_op2)) { + // This check is only valid if op1 has at least one bit + if (bit_count > 0 && instr_is_comptime(casted_op2)) { ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; BigInt bit_count_value = {0}; - bigint_init_unsigned(&bit_count_value, op1->value->type->data.integral.bit_count); + bigint_init_unsigned(&bit_count_value, bit_count); if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) { ErrorMsg* msg = ir_add_error(ira, @@ -16670,14 +16683,23 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in buf_sprintf("RHS of shift is too large for LHS type")); add_error_note(ira->codegen, msg, op1->base.source_node, buf_sprintf("type %s has only %u bits", - buf_ptr(&op1->value->type->name), - op1->value->type->data.integral.bit_count)); + buf_ptr(&op1->value->type->name), bit_count)); return ira->codegen->invalid_inst_gen; } } } + // Fast path for zero RHS + if (instr_is_comptime(casted_op2)) { + ZigValue *op2_val = ir_resolve_const(ira, casted_op2, 
UndefBad); + if (op2_val == nullptr) + return ira->codegen->invalid_inst_gen; + + if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ) + return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1); + } + if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) { ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad); if (op1_val == nullptr) @@ -16688,12 +16710,6 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in return ira->codegen->invalid_inst_gen; return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val); - } else if (op1->value->type->id == ZigTypeIdComptimeInt) { - ir_add_error(ira, &bin_op_instruction->base.base, - buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known")); - return ira->codegen->invalid_inst_gen; - } else if (instr_is_comptime(casted_op2) && bigint_cmp_zero(&casted_op2->value->data.x_bigint) == CmpEQ) { - return ir_build_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1, CastOpNoop); } return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type, diff --git a/test/stage1/behavior/math.zig b/test/stage1/behavior/math.zig index e657b5472b..fb70fb7e44 100644 --- a/test/stage1/behavior/math.zig +++ b/test/stage1/behavior/math.zig @@ -453,6 +453,25 @@ fn testShrExact(x: u8) void { expect(shifted == 0b00101101); } +test "shift left/right on u0 operand" { + const S = struct { + fn doTheTest() void { + var x: u0 = 0; + var y: u0 = 0; + expectEqual(@as(u0, 0), x << 0); + expectEqual(@as(u0, 0), x >> 0); + expectEqual(@as(u0, 0), x << y); + expectEqual(@as(u0, 0), x >> y); + expectEqual(@as(u0, 0), @shlExact(x, 0)); + expectEqual(@as(u0, 0), @shrExact(x, 0)); + expectEqual(@as(u0, 0), @shlExact(x, y)); + expectEqual(@as(u0, 0), @shrExact(x, y)); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +} + test "comptime_int addition" { comptime { 
expect(35361831660712422535336160538497375248 + 101752735581729509668353361206450473702 == 137114567242441932203689521744947848950); From 2f1052a313cb09f87f04cef56805c33be62eb169 Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Tue, 10 Mar 2020 23:50:04 +0100 Subject: [PATCH 3/3] std: Fix broken tests --- lib/std/io.zig | 10 ++++++++-- lib/std/mem.zig | 8 +++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/std/io.zig b/lib/std/io.zig index 99e9391f1d..f823eb8115 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -350,12 +350,18 @@ pub fn BitInStream(endian: builtin.Endian, comptime Error: type) type { switch (endian) { .Big => { out_buffer = @as(Buf, self.bit_buffer >> shift); - self.bit_buffer <<= n; + if (n >= u7_bit_count) + self.bit_buffer = 0 + else + self.bit_buffer <<= n; }, .Little => { const value = (self.bit_buffer << shift) >> shift; out_buffer = @as(Buf, value); - self.bit_buffer >>= n; + if (n >= u7_bit_count) + self.bit_buffer = 0 + else + self.bit_buffer >>= n; }, } self.bit_count -= n; diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 4da7829570..bee38a30f6 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -921,6 +921,9 @@ pub fn writeInt(comptime T: type, buffer: *[@divExact(T.bit_count, 8)]u8, value: pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void { assert(buffer.len >= @divExact(T.bit_count, 8)); + if (T.bit_count == 0) + return set(u8, buffer, 0); + // TODO I want to call writeIntLittle here but comptime eval facilities aren't good enough const uint = std.meta.IntType(false, T.bit_count); var bits = @truncate(uint, value); @@ -938,6 +941,9 @@ pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void { pub fn writeIntSliceBig(comptime T: type, buffer: []u8, value: T) void { assert(buffer.len >= @divExact(T.bit_count, 8)); + if (T.bit_count == 0) + return set(u8, buffer, 0); + // TODO I want to call writeIntBig here but comptime eval facilities aren't good enough const uint = 
std.meta.IntType(false, T.bit_count); var bits = @truncate(uint, value); @@ -1807,7 +1813,7 @@ test "sliceAsBytes" { } test "sliceAsBytes with sentinel slice" { - const empty_string:[:0]const u8 = ""; + const empty_string: [:0]const u8 = ""; const bytes = sliceAsBytes(empty_string); testing.expect(bytes.len == 0); }