From b08924e93845a4119b990d31e0456bff9c9dc77a Mon Sep 17 00:00:00 2001 From: samy007 Date: Mon, 14 Apr 2025 20:46:06 +0200 Subject: [PATCH] std.math.big.int: changed llshr and llshl implementation --- lib/std/math/big/int.zig | 155 +++++++++++++++++++++++++-------------- 1 file changed, 100 insertions(+), 55 deletions(-) diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index 801872b886..444aacd9dd 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -17,7 +17,6 @@ const Endian = std.builtin.Endian; const Signedness = std.builtin.Signedness; const native_endian = builtin.cpu.arch.endian(); - /// Returns the number of limbs needed to store `scalar`, which must be a /// primitive integer value. /// Note: A comptime-known upper bound of this value that may be used @@ -210,7 +209,7 @@ pub const Mutable = struct { for (self.limbs[0..self.len]) |limb| { std.debug.print("{x} ", .{limb}); } - std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.positive }); + std.debug.print("len={} capacity={} positive={}\n", .{ self.len, self.limbs.len, self.positive }); } /// Clones an Mutable and returns a new Mutable with the same value. The new Mutable is a deep copy and @@ -1104,8 +1103,8 @@ pub const Mutable = struct { /// Asserts there is enough memory to fit the result. The upper bound Limb count is /// `a.limbs.len + (shift / (@sizeOf(Limb) * 8))`. pub fn shiftLeft(r: *Mutable, a: Const, shift: usize) void { - llshl(r.limbs, a.limbs, shift); - r.normalize(a.limbs.len + (shift / limb_bits) + 1); + const new_len = llshl(r.limbs, a.limbs, shift); + r.normalize(new_len); r.positive = a.positive; } @@ -1173,8 +1172,8 @@ pub const Mutable = struct { // This shift should not be able to overflow, so invoke llshl and normalize manually // to avoid the extra required limb. - llshl(r.limbs, a.limbs, shift); - r.normalize(a.limbs.len + (shift / limb_bits)); + const new_len = llshl(r.limbs, a.limbs, shift); + r.normalize(new_len); r.positive = a.positive; } @@ -1182,7 +1181,7 @@ pub const Mutable = struct { /// r and a may alias. /// /// Asserts there is enough memory to fit the result. The upper bound Limb count is - /// `a.limbs.len - (shift / (@sizeOf(Limb) * 8))`. + /// `a.limbs.len - (shift / (@bitSizeOf(Limb)))`. pub fn shiftRight(r: *Mutable, a: Const, shift: usize) void { const full_limbs_shifted_out = shift / limb_bits; const remaining_bits_shifted_out = shift % limb_bits; @@ -1210,9 +1209,9 @@ pub const Mutable = struct { break :nonzero a.limbs[full_limbs_shifted_out] << not_covered != 0; }; - llshr(r.limbs, a.limbs, shift); + const new_len = llshr(r.limbs, a.limbs, shift); - r.len = a.limbs.len - full_limbs_shifted_out; + r.len = new_len; r.positive = a.positive; if (nonzero_negative_shiftout) r.addScalar(r.toConst(), -1); r.normalize(r.len); @@ -1971,7 +1970,7 @@ pub const Const = struct { for (self.limbs[0..self.limbs.len]) |limb| { std.debug.print("{x} ", .{limb}); } - std.debug.print("positive={}\n", .{self.positive}); + std.debug.print("len={} positive={}\n", .{ self.len, self.positive }); } pub fn abs(self: Const) Const { @@ -2673,7 +2672,7 @@ pub const Managed = struct { for (self.limbs[0..self.len()]) |limb| { std.debug.print("{x} ", .{limb}); } - std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.isPositive() }); + std.debug.print("len={} capacity={} positive={}\n", .{ self.len(), self.limbs.len, self.isPositive() }); } /// Negate the sign. @@ -3711,68 +3710,114 @@ fn lldiv0p5(quo: []Limb, rem: *Limb, a: []const Limb, b: HalfLimb) void { } } -fn llshl(r: []Limb, a: []const Limb, shift: usize) void { - @setRuntimeSafety(debug_safety); - assert(a.len >= 1); +/// Performs r = a << shift and returns the amount of limbs affected +/// +/// if a and r overlaps, then r.ptr >= a.ptr is asserted +/// r must have the capacity to store a << shift +fn llshl(r: []Limb, a: []const Limb, shift: usize) usize { + std.debug.assert(a.len >= 1); + if (slicesOverlap(a, r)) + std.debug.assert(@intFromPtr(r.ptr) >= @intFromPtr(a.ptr)); - const interior_limb_shift = @as(Log2Limb, @truncate(shift)); + if (shift == 0) { + if (a.ptr != r.ptr) + std.mem.copyBackwards(Limb, r[0..a.len], a); + return a.len; + } + if (shift >= limb_bits) { + const limb_shift = shift / limb_bits; + + const affected = llshl(r[limb_shift..], a, shift % limb_bits); + @memset(r[0..limb_shift], 0); + + return limb_shift + affected; + } + + // shift is guaranteed to be < limb_bits + const bit_shift: Log2Limb = @truncate(shift); + const opposite_bit_shift: Log2Limb = @truncate(limb_bits - bit_shift); // We only need the extra limb if the shift of the last element overflows. // This is useful for the implementation of `shiftLeftSat`. - if (a[a.len - 1] << interior_limb_shift >> interior_limb_shift != a[a.len - 1]) { - assert(r.len >= a.len + (shift / limb_bits) + 1); + const overflows = a[a.len - 1] >> opposite_bit_shift != 0; + if (overflows) { + std.debug.assert(r.len >= a.len + 1); } else { - assert(r.len >= a.len + (shift / limb_bits)); + std.debug.assert(r.len >= a.len); } - const limb_shift = shift / limb_bits + 1; - - var carry: Limb = 0; - var i: usize = 0; - while (i < a.len) : (i += 1) { - const src_i = a.len - i - 1; - const dst_i = src_i + limb_shift; - - const src_digit = a[src_i]; - r[dst_i] = carry | @call(.always_inline, math.shr, .{ - Limb, - src_digit, - limb_bits - @as(Limb, @intCast(interior_limb_shift)), - }); - carry = (src_digit << interior_limb_shift); + var i: usize = a.len; + if (overflows) { + // r is asserted to be large enough above + r[a.len] = a[a.len - 1] >> opposite_bit_shift; } + while (i > 1) { + i -= 1; + r[i] = (a[i - 1] >> opposite_bit_shift) | (a[i] << bit_shift); + } + r[0] = a[0] << bit_shift; - r[limb_shift - 1] = carry; - @memset(r[0 .. limb_shift - 1], 0); + return a.len + @intFromBool(overflows); } -fn llshr(r: []Limb, a: []const Limb, shift: usize) void { - @setRuntimeSafety(debug_safety); - assert(a.len >= 1); - assert(r.len >= a.len - (shift / limb_bits)); +/// Performs r = a >> shift and returns the amount of limbs affected +/// +/// if a and r overlaps, then r.ptr <= a.ptr is asserted +/// r must have the capacity to store a >> shift +/// +/// See tests below for examples of behaviour +fn llshr(r: []Limb, a: []const Limb, shift: usize) usize { + if (slicesOverlap(a, r)) + std.debug.assert(@intFromPtr(r.ptr) <= @intFromPtr(a.ptr)); - const limb_shift = shift / limb_bits; - const interior_limb_shift = @as(Log2Limb, @truncate(shift)); + if (a.len == 0) return 0; + + if (shift == 0) { + std.debug.assert(r.len >= a.len); + + if (a.ptr != r.ptr) + std.mem.copyForwards(Limb, r[0..a.len], a); + return a.len; + } + if (shift >= limb_bits) { + if (shift / limb_bits >= a.len) { + r[0] = 0; + return 1; + } + return llshr(r, a[shift / limb_bits ..], shift % limb_bits); + } + + // shift is guaranteed to be < limb_bits + const bit_shift: Log2Limb = @truncate(shift); + const opposite_bit_shift: Log2Limb = @truncate(limb_bits - bit_shift); + + // special case, where there is a risk to set r to 0 + if (a.len == 1) { + r[0] = a[0] >> bit_shift; + return 1; + } + if (a.len == 0) { + r[0] = 0; + return 1; + } + + // if the most significant limb becomes 0 after the shift + const shrink = a[a.len - 1] >> bit_shift == 0; + std.debug.assert(r.len >= a.len - @intFromBool(!shrink)); var i: usize = 0; - while (i < a.len - limb_shift) : (i += 1) { - const dst_i = i; - const src_i = dst_i + limb_shift; - - const src_digit = a[src_i]; - const src_digit_next = if (src_i + 1 < a.len) a[src_i + 1] else 0; - const carry = @call(.always_inline, math.shl, .{ - Limb, - src_digit_next, - limb_bits - @as(Limb, @intCast(interior_limb_shift)), - }); - r[dst_i] = carry | (src_digit >> interior_limb_shift); + while (i < a.len - 1) : (i += 1) { + r[i] = (a[i] >> bit_shift) | (a[i + 1] << opposite_bit_shift); } + + if (!shrink) + r[i] = a[i] >> bit_shift; + + return a.len - @intFromBool(shrink); } // r = ~r fn llnot(r: []Limb) void { - for (r) |*elem| { elem.* = ~elem.*; } @@ -4107,7 +4152,7 @@ fn llsquareBasecase(r: []Limb, x: []const Limb) void { } // Each product appears twice, multiply by 2 - llshl(r, r[0 .. 2 * x_norm.len], 1); + _ = llshl(r, r[0 .. 2 * x_norm.len], 1); for (x_norm, 0..) |v, i| { // Compute and add the squares