From aa29f4a8037ae74fb2d97793312ef8c5262d025a Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Tue, 21 Dec 2021 12:45:48 +0000 Subject: [PATCH] stage1: fix saturating arithmetic producing incorrect results on type comptime_int, allow saturating left shift on type comptime int --- src/stage1/ir.cpp | 38 ++++---- test/behavior/saturating_arithmetic.zig | 26 ++++++ test/behavior/wrapping_arithmetic.zig | 110 ++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 15 deletions(-) create mode 100644 test/behavior/wrapping_arithmetic.zig diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 1b9e9638e2..574d3a91a7 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -10230,13 +10230,7 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b // comptime_int has no finite bit width casted_op2 = op2; - if (op_id == IrBinOpShlSat) { - ir_add_error_node(ira, bin_op_instruction->base.source_node, - buf_sprintf("saturating shift on a comptime_int which has unlimited bits")); - return ira->codegen->invalid_inst_gen; - } - - if (op_id == IrBinOpBitShiftLeftLossy) { + if (op_id == IrBinOpBitShiftLeftLossy || op_id == IrBinOpShlSat) { op_id = IrBinOpBitShiftLeftExact; } @@ -10398,6 +10392,25 @@ static bool ok_float_op(IrBinOp op) { zig_unreachable(); } +static IrBinOp map_comptime_arithmetic_op(IrBinOp op) { + switch (op) { + case IrBinOpAddWrap: + case IrBinOpAddSat: + return IrBinOpAdd; + + case IrBinOpSubWrap: + case IrBinOpSubSat: + return IrBinOpSub; + + case IrBinOpMultWrap: + case IrBinOpMultSat: + return IrBinOpMult; + + default: + return op; + } +} + static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) { switch (op) { case IrBinOpAdd: @@ -10620,15 +10633,10 @@ static Stage1AirInst *ir_analyze_bin_op_math(IrAnalyze *ira, Stage1ZirInstBinOp if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; - // Comptime integers have no fixed size + // Comptime integers have no fixed size, so wrapping or saturating operations should be mapped + // to their non wrapping or saturating equivalents if (scalar_type->id == ZigTypeIdComptimeInt) { - if (op_id == IrBinOpAddWrap) { - op_id = IrBinOpAdd; - } else if (op_id == IrBinOpSubWrap) { - op_id = IrBinOpSub; - } else if (op_id == IrBinOpMultWrap) { - op_id = IrBinOpMult; - } + op_id = map_comptime_arithmetic_op(op_id); } if (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2)) { diff --git a/test/behavior/saturating_arithmetic.zig b/test/behavior/saturating_arithmetic.zig index ef6c24a389..94e1f1ae81 100644 --- a/test/behavior/saturating_arithmetic.zig +++ b/test/behavior/saturating_arithmetic.zig @@ -29,8 +29,14 @@ test "saturating add" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatAdd(comptime_int, 0, 0, 0); + comptime try S.testSatAdd(comptime_int, 3, 2, 5); + comptime try S.testSatAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512); + comptime try S.testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501); } test "saturating subtraction" { @@ -56,8 +62,14 @@ test "saturating subtraction" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatSub(comptime_int, 0, 0, 0); + comptime try S.testSatSub(comptime_int, 3, 2, 1); + comptime try S.testSatSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602); + comptime try S.testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515); } test "saturating multiplication" { @@ -90,6 +102,11 @@ test "saturating multiplication" { try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatMul(comptime_int, 0, 0, 0); + comptime try S.testSatMul(comptime_int, 3, 2, 6); + comptime try S.testSatMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935); + comptime try S.testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556); } test "saturating shift-left" { @@ -107,6 +124,7 @@ test "saturating shift-left" { try testSatShl(u8, 1, 2, 4); try testSatShl(u8, 255, 1, 255); } + fn testSatShl(comptime T: type, lhs: T, rhs: T, expected: T) !void { try expect((lhs <<| rhs) == expected); @@ -115,8 +133,14 @@ test "saturating shift-left" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatShl(comptime_int, 0, 0, 0); + comptime try S.testSatShl(comptime_int, 1, 2, 4); + comptime try S.testSatShl(comptime_int, 13, 150, 18554220005177478453757717602843436772975706112); + comptime try S.testSatShl(comptime_int, -582769, 180, -893090893854873184096635538665358532628308979495815656505344); } test "saturating shl uses the LHS type" { @@ -139,4 +163,6 @@ test "saturating shl uses the LHS type" { try expect((@as(u8, 1) <<| 8) == 255); try expect((@as(u8, 1) <<| rhs_const) == 255); try expect((@as(u8, 1) <<| rhs_var) == 255); + + try expect((1 <<| @as(u8, 200)) == 1606938044258990275541962092341162602522202993782792835301376); } diff --git a/test/behavior/wrapping_arithmetic.zig b/test/behavior/wrapping_arithmetic.zig new file mode 100644 index 0000000000..5ee9b20780 --- /dev/null +++ b/test/behavior/wrapping_arithmetic.zig @@ -0,0 +1,110 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const minInt = std.math.minInt; +const maxInt = std.math.maxInt; +const expect = std.testing.expect; + +test "wrapping add" { + const S = struct { + fn doTheTest() !void { + try testWrapAdd(i8, -3, 10, 7); + try testWrapAdd(i8, -128, -128, 0); + try testWrapAdd(i2, 1, 1, -2); + try testWrapAdd(i64, maxInt(i64), 1, minInt(i64)); + try testWrapAdd(i128, maxInt(i128), -maxInt(i128), 0); + try testWrapAdd(i128, minInt(i128), maxInt(i128), -1); + try testWrapAdd(i8, 127, 127, -2); + try testWrapAdd(u8, 3, 10, 13); + try testWrapAdd(u8, 255, 255, 254); + try testWrapAdd(u2, 3, 2, 1); + try testWrapAdd(u3, 7, 1, 0); + try testWrapAdd(u128, maxInt(u128), 1, minInt(u128)); + } + + fn testWrapAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs +% rhs) == expected); + + var x = lhs; + x +%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapAdd(comptime_int, 0, 0, 0); + comptime try S.testWrapAdd(comptime_int, 3, 2, 5); + comptime try S.testWrapAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512); + comptime try S.testWrapAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501); +} + +test "wrapping subtraction" { + const S = struct { + fn doTheTest() !void { + try testWrapSub(i8, -3, 10, -13); + try testWrapSub(i8, -128, -128, 0); + try testWrapSub(i8, -1, 127, -128); + try testWrapSub(i64, minInt(i64), 1, maxInt(i64)); + try testWrapSub(i128, maxInt(i128), -1, minInt(i128)); + try testWrapSub(i128, minInt(i128), -maxInt(i128), -1); + try testWrapSub(u8, 10, 3, 7); + try testWrapSub(u8, 0, 255, 1); + try testWrapSub(u5, 0, 31, 1); + try testWrapSub(u128, 0, maxInt(u128), 1); + } + + fn testWrapSub(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs -% rhs) == expected); + + var x = lhs; + x -%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapSub(comptime_int, 0, 0, 0); + comptime try S.testWrapSub(comptime_int, 3, 2, 1); + comptime try S.testWrapSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602); + comptime try S.testWrapSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515); +} + +test "wrapping multiplication" { + // TODO: once #9660 has been solved, remove this line + if (builtin.cpu.arch == .wasm32) return error.SkipZigTest; + + const S = struct { + fn doTheTest() !void { + try testWrapMul(i8, -3, 10, -30); + try testWrapMul(i4, 2, 4, -8); + try testWrapMul(i8, 2, 127, -2); + try testWrapMul(i8, -128, -128, 0); + try testWrapMul(i8, maxInt(i8), maxInt(i8), 1); + try testWrapMul(i16, maxInt(i16), -1, minInt(i16) + 1); + try testWrapMul(i128, maxInt(i128), -1, minInt(i128) + 1); + try testWrapMul(i128, minInt(i128), -1, minInt(i128)); + try testWrapMul(u8, 10, 3, 30); + try testWrapMul(u8, 2, 255, 254); + try testWrapMul(u128, maxInt(u128), maxInt(u128), 1); + } + + fn testWrapMul(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs *% rhs) == expected); + + var x = lhs; + x *%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapMul(comptime_int, 0, 0, 0); + comptime try S.testWrapMul(comptime_int, 3, 2, 6); + comptime try S.testWrapMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935); + comptime try S.testWrapMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556); +}