From aa29f4a8037ae74fb2d97793312ef8c5262d025a Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Tue, 21 Dec 2021 12:45:48 +0000 Subject: [PATCH 1/3] stage1: fix saturating arithmetic producing incorrect results on type comptime_int, allow saturating left shift on type comptime int --- src/stage1/ir.cpp | 38 ++++---- test/behavior/saturating_arithmetic.zig | 26 ++++++ test/behavior/wrapping_arithmetic.zig | 110 ++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 15 deletions(-) create mode 100644 test/behavior/wrapping_arithmetic.zig diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 1b9e9638e2..574d3a91a7 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -10230,13 +10230,7 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b // comptime_int has no finite bit width casted_op2 = op2; - if (op_id == IrBinOpShlSat) { - ir_add_error_node(ira, bin_op_instruction->base.source_node, - buf_sprintf("saturating shift on a comptime_int which has unlimited bits")); - return ira->codegen->invalid_inst_gen; - } - - if (op_id == IrBinOpBitShiftLeftLossy) { + if (op_id == IrBinOpBitShiftLeftLossy || op_id == IrBinOpShlSat) { op_id = IrBinOpBitShiftLeftExact; } @@ -10398,6 +10392,25 @@ static bool ok_float_op(IrBinOp op) { zig_unreachable(); } +static IrBinOp map_comptime_arithmetic_op(IrBinOp op) { + switch (op) { + case IrBinOpAddWrap: + case IrBinOpAddSat: + return IrBinOpAdd; + + case IrBinOpSubWrap: + case IrBinOpSubSat: + return IrBinOpSub; + + case IrBinOpMultWrap: + case IrBinOpMultSat: + return IrBinOpMult; + + default: + return op; + } +} + static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) { switch (op) { case IrBinOpAdd: @@ -10620,15 +10633,10 @@ static Stage1AirInst *ir_analyze_bin_op_math(IrAnalyze *ira, Stage1ZirInstBinOp if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; - // Comptime integers have no fixed size + // Comptime integers have no fixed size, so wrapping or saturating operations should be mapped + // to their non wrapping or saturating equivalents if (scalar_type->id == ZigTypeIdComptimeInt) { - if (op_id == IrBinOpAddWrap) { - op_id = IrBinOpAdd; - } else if (op_id == IrBinOpSubWrap) { - op_id = IrBinOpSub; - } else if (op_id == IrBinOpMultWrap) { - op_id = IrBinOpMult; - } + op_id = map_comptime_arithmetic_op(op_id); } if (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2)) { diff --git a/test/behavior/saturating_arithmetic.zig b/test/behavior/saturating_arithmetic.zig index ef6c24a389..94e1f1ae81 100644 --- a/test/behavior/saturating_arithmetic.zig +++ b/test/behavior/saturating_arithmetic.zig @@ -29,8 +29,14 @@ test "saturating add" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatAdd(comptime_int, 0, 0, 0); + comptime try S.testSatAdd(comptime_int, 3, 2, 5); + comptime try S.testSatAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512); + comptime try S.testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501); } test "saturating subtraction" { @@ -56,8 +62,14 @@ test "saturating subtraction" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatSub(comptime_int, 0, 0, 0); + comptime try S.testSatSub(comptime_int, 3, 2, 1); + comptime try S.testSatSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602); + comptime try S.testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515); } test "saturating multiplication" { @@ -90,6 +102,11 @@ test "saturating multiplication" { try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatMul(comptime_int, 0, 0, 0); + comptime try S.testSatMul(comptime_int, 3, 2, 6); + comptime try S.testSatMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935); + comptime try S.testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556); } test "saturating shift-left" { @@ -107,6 +124,7 @@ test "saturating shift-left" { try testSatShl(u8, 1, 2, 4); try testSatShl(u8, 255, 1, 255); } + fn testSatShl(comptime T: type, lhs: T, rhs: T, expected: T) !void { try expect((lhs <<| rhs) == expected); @@ -115,8 +133,14 @@ test "saturating shift-left" { try expect(x == expected); } }; + try S.doTheTest(); comptime try S.doTheTest(); + + comptime try S.testSatShl(comptime_int, 0, 0, 0); + comptime try S.testSatShl(comptime_int, 1, 2, 4); + comptime try S.testSatShl(comptime_int, 13, 150, 18554220005177478453757717602843436772975706112); + comptime try S.testSatShl(comptime_int, -582769, 180, -893090893854873184096635538665358532628308979495815656505344); } test "saturating shl uses the LHS type" { @@ -139,4 +163,6 @@ test "saturating shl uses the LHS type" { try expect((@as(u8, 1) <<| 8) == 255); try expect((@as(u8, 1) <<| rhs_const) == 255); try expect((@as(u8, 1) <<| rhs_var) == 255); + + try expect((1 <<| @as(u8, 200)) == 1606938044258990275541962092341162602522202993782792835301376); } diff --git a/test/behavior/wrapping_arithmetic.zig b/test/behavior/wrapping_arithmetic.zig new file mode 100644 index 0000000000..5ee9b20780 --- /dev/null +++ b/test/behavior/wrapping_arithmetic.zig @@ -0,0 +1,110 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const minInt = std.math.minInt; +const maxInt = std.math.maxInt; +const expect = std.testing.expect; + +test "wrapping add" { + const S = struct { + fn doTheTest() !void { + try testWrapAdd(i8, -3, 10, 7); + try testWrapAdd(i8, -128, -128, 0); + try testWrapAdd(i2, 1, 1, -2); + try testWrapAdd(i64, maxInt(i64), 1, minInt(i64)); + try testWrapAdd(i128, maxInt(i128), -maxInt(i128), 0); + try testWrapAdd(i128, minInt(i128), maxInt(i128), -1); + try testWrapAdd(i8, 127, 127, -2); + try testWrapAdd(u8, 3, 10, 13); + try testWrapAdd(u8, 255, 255, 254); + try testWrapAdd(u2, 3, 2, 1); + try testWrapAdd(u3, 7, 1, 0); + try testWrapAdd(u128, maxInt(u128), 1, minInt(u128)); + } + + fn testWrapAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs +% rhs) == expected); + + var x = lhs; + x +%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapAdd(comptime_int, 0, 0, 0); + comptime try S.testWrapAdd(comptime_int, 3, 2, 5); + comptime try S.testWrapAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512); + comptime try S.testWrapAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501); +} + +test "wrapping subtraction" { + const S = struct { + fn doTheTest() !void { + try testWrapSub(i8, -3, 10, -13); + try testWrapSub(i8, -128, -128, 0); + try testWrapSub(i8, -1, 127, -128); + try testWrapSub(i64, minInt(i64), 1, maxInt(i64)); + try testWrapSub(i128, maxInt(i128), -1, minInt(i128)); + try testWrapSub(i128, minInt(i128), -maxInt(i128), -1); + try testWrapSub(u8, 10, 3, 7); + try testWrapSub(u8, 0, 255, 1); + try testWrapSub(u5, 0, 31, 1); + try testWrapSub(u128, 0, maxInt(u128), 1); + } + + fn testWrapSub(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs -% rhs) == expected); + + var x = lhs; + x -%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapSub(comptime_int, 0, 0, 0); + comptime try S.testWrapSub(comptime_int, 3, 2, 1); + comptime try S.testWrapSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602); + comptime try S.testWrapSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515); +} + +test "wrapping multiplication" { + // TODO: once #9660 has been solved, remove this line + if (builtin.cpu.arch == .wasm32) return error.SkipZigTest; + + const S = struct { + fn doTheTest() !void { + try testWrapMul(i8, -3, 10, -30); + try testWrapMul(i4, 2, 4, -8); + try testWrapMul(i8, 2, 127, -2); + try testWrapMul(i8, -128, -128, 0); + try testWrapMul(i8, maxInt(i8), maxInt(i8), 1); + try testWrapMul(i16, maxInt(i16), -1, minInt(i16) + 1); + try testWrapMul(i128, maxInt(i128), -1, minInt(i128) + 1); + try testWrapMul(i128, minInt(i128), -1, minInt(i128)); + try testWrapMul(u8, 10, 3, 30); + try testWrapMul(u8, 2, 255, 254); + try testWrapMul(u128, maxInt(u128), maxInt(u128), 1); + } + + fn testWrapMul(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs *% rhs) == expected); + + var x = lhs; + x *%= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); + + comptime try S.testWrapMul(comptime_int, 0, 0, 0); + comptime try S.testWrapMul(comptime_int, 3, 2, 6); + comptime try S.testWrapMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935); + comptime try S.testWrapMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556); +} From 54634991a2b9767d20725b20a2c273b6db60a825 Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Sun, 26 Dec 2021 00:04:48 +0000 Subject: [PATCH 2/3] stage1: fix issue with bigint_add that caused incorrect results when adding a large and a small comptime_int of differing sign stage1: fix issue with to_twos_complement that caused a compile error when performing wrapping addition on two signed ints, both of which have the minimum possible value --- src/stage1/bigint.cpp | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/stage1/bigint.cpp b/src/stage1/bigint.cpp index f027281561..eab0f037cf 100644 --- a/src/stage1/bigint.cpp +++ b/src/stage1/bigint.cpp @@ -60,6 +60,9 @@ static void to_twos_complement(BigInt *dest, const BigInt *op, size_t bit_count) bigint_init_unsigned(dest, 0); return; } + + BigInt pos_op = {0}; + if (op->is_negative) { BigInt negated = {0}; bigint_negate(&negated, op); @@ -70,13 +73,14 @@ static void to_twos_complement(BigInt *dest, const BigInt *op, size_t bit_count) BigInt one = {0}; bigint_init_unsigned(&one, 1); - bigint_add(dest, &inverted, &one); - return; + bigint_add(&pos_op, &inverted, &one); + } else { + bigint_init_bigint(&pos_op, op); } dest->is_negative = false; - const uint64_t *op_digits = bigint_ptr(op); - if (op->digit_count == 1) { + const uint64_t *op_digits = bigint_ptr(&pos_op); + if (pos_op.digit_count == 1) { dest->data.digit = op_digits[0]; if (bit_count < 64) { dest->data.digit &= (1ULL << bit_count) - 1; @@ -98,11 +102,11 @@ static void to_twos_complement(BigInt *dest, const BigInt *op, size_t bit_count) } dest->data.digits = heap::c_allocator.allocate_nonzero(dest->digit_count); for (size_t i = 0; i < digits_to_copy; i += 1) { - uint64_t digit = (i < op->digit_count) ? op_digits[i] : 0; + uint64_t digit = (i < pos_op.digit_count) ? op_digits[i] : 0; dest->data.digits[i] = digit; } if (leftover_bits != 0) { - uint64_t digit = (digits_to_copy < op->digit_count) ? op_digits[digits_to_copy] : 0; + uint64_t digit = (digits_to_copy < pos_op.digit_count) ? op_digits[digits_to_copy] : 0; dest->data.digits[digits_to_copy] = digit & ((1ULL << leftover_bits) - 1); } bigint_normalize(dest); @@ -469,18 +473,18 @@ void bigint_min(BigInt* dest, const BigInt *op1, const BigInt *op2) { } /// clamps op within bit_count/signedness boundaries -/// signed bounds are [-2^(bit_count-1)..2^(bit_count-1)-1] -/// unsigned bounds are [0..2^bit_count-1] +/// signed bounds are [-2^(bit_count-1)..2^(bit_count-1)-1] +/// unsigned bounds are [0..2^bit_count-1] void bigint_clamp_by_bitcount(BigInt* dest, uint32_t bit_count, bool is_signed) { - // compute the number of bits required to store the value, and use that + // compute the number of bits required to store the value, and use that // to decide whether to clamp the result bool is_negative = dest->is_negative; - // to workaround the fact this bits_needed calculation would yield 65 or more for - // all negative numbers, set is_negative to false. this is a cheap way to find - // bits_needed(abs(dest)). + // to workaround the fact this bits_needed calculation would yield 65 or more for + // all negative numbers, set is_negative to false. this is a cheap way to find + // bits_needed(abs(dest)). dest->is_negative = false; // because we've set is_negative to false, we have to account for the extra bit here - // by adding 1 additional bit_needed when (is_negative && !is_signed). + // by adding 1 additional bit_needed when (is_negative && !is_signed). size_t full_bits = dest->digit_count * 64; size_t leading_zero_count = bigint_clz(dest, full_bits); size_t bits_needed = full_bits - leading_zero_count + (is_negative && !is_signed); @@ -491,7 +495,7 @@ void bigint_clamp_by_bitcount(BigInt* dest, uint32_t bit_count, bool is_signed) bigint_init_unsigned(&one, 1); BigInt bit_count_big; bigint_init_unsigned(&bit_count_big, bit_count); - + if(is_signed) { if(is_negative) { BigInt bound; @@ -639,25 +643,22 @@ void bigint_add(BigInt *dest, const BigInt *op1, const BigInt *op2) { size_t i = 1; for (;;) { - bool found_digit = false; uint64_t x = bigger_op_digits[i]; uint64_t prev_overflow = overflow; overflow = 0; if (i < smaller_op->digit_count) { - found_digit = true; uint64_t digit = smaller_op_digits[i]; overflow += sub_u64_overflow(x, digit, &x); } - if (sub_u64_overflow(x, prev_overflow, &x)) { - found_digit = true; - overflow += 1; - } + + overflow += sub_u64_overflow(x, prev_overflow, &x); dest->data.digits[i] = x; i += 1; - if (!found_digit || i >= bigger_op->digit_count) + if (i >= bigger_op->digit_count) { break; + } } assert(overflow == 0); dest->digit_count = i; From 3c53667db8d44777c2fbceaf3c6d3a22a6c9caad Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Mon, 27 Dec 2021 22:52:56 +0000 Subject: [PATCH 3/3] stage2: fix bug where performing wrapping or saturating arithmetic or saturating left shift on type comptime_int executed unreachable code --- src/Sema.zig | 36 +++++++++++++++++++++++------------- src/value.zig | 12 ++++++++++++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index ac67b3f07f..8cb38cc9f5 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -7474,7 +7474,10 @@ fn zirShl( } const val = switch (air_tag) { .shl_exact => return sema.fail(block, lhs_src, "TODO implement Sema for comptime shl_exact", .{}), - .shl_sat => try lhs_val.shlSat(rhs_val, lhs_ty, sema.arena, sema.mod.getTarget()), + .shl_sat => if (lhs_ty.zigTypeTag() == .ComptimeInt) + try lhs_val.shl(rhs_val, sema.arena) + else + try lhs_val.shlSat(rhs_val, lhs_ty, sema.arena, sema.mod.getTarget()), .shl => try lhs_val.shl(rhs_val, sema.arena), else => unreachable, }; @@ -8189,10 +8192,12 @@ fn analyzeArithmetic( return casted_lhs; } if (maybe_lhs_val) |lhs_val| { - return sema.addConstant( - scalar_type, - try lhs_val.intAddSat(rhs_val, scalar_type, sema.arena, target), - ); + const val = if (scalar_tag == .ComptimeInt) + try lhs_val.intAdd(rhs_val, sema.arena) + else + try lhs_val.intAddSat(rhs_val, scalar_type, sema.arena, target); + + return sema.addConstant(scalar_type, val); } else break :rs .{ .src = lhs_src, .air_tag = .add_sat }; } else break :rs .{ .src = rhs_src, .air_tag = .add_sat }; }, @@ -8280,10 +8285,12 @@ fn analyzeArithmetic( return sema.addConstUndef(scalar_type); } if (maybe_rhs_val) |rhs_val| { - return sema.addConstant( - scalar_type, - try lhs_val.intSubSat(rhs_val, scalar_type, sema.arena, target), - ); + const val = if (scalar_tag == .ComptimeInt) + try lhs_val.intSub(rhs_val, sema.arena) + else + try lhs_val.intSubSat(rhs_val, scalar_type, sema.arena, target); + + return sema.addConstant(scalar_type, val); } else break :rs .{ .src = rhs_src, .air_tag = .sub_sat }; } else break :rs .{ .src = lhs_src, .air_tag = .sub_sat }; }, @@ -8663,10 +8670,13 @@ fn analyzeArithmetic( if (lhs_val.isUndef()) { return sema.addConstUndef(scalar_type); } - return sema.addConstant( - scalar_type, - try lhs_val.intMulSat(rhs_val, scalar_type, sema.arena, target), - ); + + const val = if (scalar_tag == .ComptimeInt) + try lhs_val.intMul(rhs_val, sema.arena) + else + try lhs_val.intMulSat(rhs_val, scalar_type, sema.arena, target); + + return sema.addConstant(scalar_type, val); } else break :rs .{ .src = lhs_src, .air_tag = .mul_sat }; } else break :rs .{ .src = rhs_src, .air_tag = .mul_sat }; }, diff --git a/src/value.zig b/src/value.zig index faf4f38e80..df1531533b 100644 --- a/src/value.zig +++ b/src/value.zig @@ -2275,6 +2275,10 @@ pub const Value = extern union { ) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); + if (ty.zigTypeTag() == .ComptimeInt) { + return intAdd(lhs, rhs, arena); + } + if (ty.isAnyFloat()) { return floatAdd(lhs, rhs, ty, arena); } @@ -2361,6 +2365,10 @@ pub const Value = extern union { ) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); + if (ty.zigTypeTag() == .ComptimeInt) { + return intSub(lhs, rhs, arena); + } + if (ty.isAnyFloat()) { return floatSub(lhs, rhs, ty, arena); } @@ -2440,6 +2448,10 @@ pub const Value = extern union { ) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); + if (ty.zigTypeTag() == .ComptimeInt) { + return intMul(lhs, rhs, arena); + } + if (ty.isAnyFloat()) { return floatMul(lhs, rhs, ty, arena); }