From 55fe34100f8b516480cf530eb58d00ea8b665765 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Sat, 16 Jul 2022 16:10:11 +0300 Subject: [PATCH] Sema: exact division safety --- src/Sema.zig | 43 +++++++++++++++++++ test/behavior/math.zig | 1 + .../exact division failure - vectors.zig | 10 +++-- test/cases/safety/exact division failure.zig | 10 +++-- 4 files changed, 56 insertions(+), 8 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index e30b7f5cec..f8b9d044d0 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -11917,6 +11917,47 @@ fn analyzeArithmetic( }, else => {}, } + if (rs.air_tag == .div_exact) { + const result = try block.addBinOp(.div_exact, casted_lhs, casted_rhs); + const ok = if (scalar_tag == .Float) ok: { + const floored = try block.addUnOp(.floor, result); + + if (resolved_type.zigTypeTag() == .Vector) { + const eql = try block.addCmpVector(result, floored, .eq, try sema.addType(resolved_type)); + break :ok try block.addInst(.{ + .tag = .reduce, + .data = .{ .reduce = .{ + .operand = eql, + .operation = .And, + } }, + }); + } else { + const is_in_range = try block.addBinOp(.cmp_eq, result, floored); + break :ok is_in_range; + } + } else ok: { + const remainder = try block.addBinOp(.rem, casted_lhs, casted_rhs); + + if (resolved_type.zigTypeTag() == .Vector) { + const zero_val = try Value.Tag.repeated.create(sema.arena, Value.zero); + const zero = try sema.addConstant(sema.typeOf(casted_rhs), zero_val); + const eql = try block.addCmpVector(remainder, zero, .eq, try sema.addType(resolved_type)); + break :ok try block.addInst(.{ + .tag = .reduce, + .data = .{ .reduce = .{ + .operand = eql, + .operation = .And, + } }, + }); + } else { + const zero = try sema.addConstant(sema.typeOf(casted_rhs), Value.zero); + const is_in_range = try block.addBinOp(.cmp_eq, remainder, zero); + break :ok is_in_range; + } + }; + try sema.addSafetyCheck(block, ok, .exact_division_remainder); + return result; + } } return block.addBinOp(rs.air_tag, casted_lhs, casted_rhs); } @@ -18856,6 +18897,7 @@ pub const PanicId = enum { shr_overflow, divide_by_zero, remainder_division_zero_negative, + exact_division_remainder, }; fn addSafetyCheck( @@ -19077,6 +19119,7 @@ fn safetyPanic( .shr_overflow => "right shift overflowed bits", .divide_by_zero => "division by zero", .remainder_division_zero_negative => "remainder division by zero or negative value", + .exact_division_remainder => "exact division produced remainder", }; const msg_inst = msg_inst: { diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 7b280bca4e..c8d0becbd6 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -377,6 +377,7 @@ fn testBinaryNot(x: u16) !void { } test "division" { + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO diff --git a/test/cases/safety/exact division failure - vectors.zig b/test/cases/safety/exact division failure - vectors.zig index a514213f58..9b792b33cf 100644 --- a/test/cases/safety/exact division failure - vectors.zig +++ b/test/cases/safety/exact division failure - vectors.zig @@ -1,9 +1,11 @@ const std = @import("std"); pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn { - _ = message; _ = stack_trace; - std.process.exit(0); + if (std.mem.eql(u8, message, "exact division produced remainder")) { + std.process.exit(0); + } + std.process.exit(1); } pub fn main() !void { @@ -17,5 +19,5 @@ fn divExact(a: @Vector(4, i32), b: @Vector(4, i32)) @Vector(4, i32) { return @divExact(a, b); } // run -// backend=stage1 -// target=native \ No newline at end of file +// backend=llvm +// target=native diff --git a/test/cases/safety/exact division failure.zig b/test/cases/safety/exact division failure.zig index 5e30f14b06..ea4d39ed22 100644 --- a/test/cases/safety/exact division failure.zig +++ b/test/cases/safety/exact division failure.zig @@ -1,9 +1,11 @@ const std = @import("std"); pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn { - _ = message; _ = stack_trace; - std.process.exit(0); + if (std.mem.eql(u8, message, "exact division produced remainder")) { + std.process.exit(0); + } + std.process.exit(1); } pub fn main() !void { @@ -15,5 +17,5 @@ fn divExact(a: i32, b: i32) i32 { return @divExact(a, b); } // run -// backend=stage1 -// target=native \ No newline at end of file +// backend=llvm +// target=native