diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index b55b1eb9b5..90961db3dc 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -86,6 +86,24 @@ pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb { return r1; } +/// a - b * c - *carry, sets carry to the overflow bits +fn subMulLimbWithBorrow(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb { + // r1 = a - *carry + var r1: Limb = undefined; + const c1: Limb = @boolToInt(@subWithOverflow(Limb, a, carry.*, &r1)); + + // r2 = b * c + const bc = @as(DoubleLimb, std.math.mulWide(Limb, b, c)); + const r2 = @truncate(Limb, bc); + const c2 = @truncate(Limb, bc >> limb_bits); + + // r1 = r1 - r2 + const c3: Limb = @boolToInt(@subWithOverflow(Limb, r1, r2, &r1)); + carry.* = c1 + c2 + c3; + + return r1; +} + /// Used to indicate either limit of a 2s-complement integer. pub const TwosCompIntLimit = enum { // The low limit, either 0x00 (unsigned) or (-)0x80 (signed) for an 8-bit integer. @@ -640,7 +658,7 @@ pub const Mutable = struct { mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0); - llmulacc(allocator, rma.limbs, a.limbs, b.limbs); + llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs); rma.normalize(a.limbs.len + b.limbs.len); rma.positive = (a.positive == b.positive); @@ -665,9 +683,9 @@ pub const Mutable = struct { mem.set(Limb, rma.limbs[0..req_limbs], 0); if (a_limbs.len >= b_limbs.len) { - llmulacc_lo(rma.limbs, a_limbs, b_limbs); + llmulaccLow(rma.limbs, a_limbs, b_limbs); } else { - llmulacc_lo(rma.limbs, b_limbs, a_limbs); + llmulaccLow(rma.limbs, b_limbs, a_limbs); } rma.normalize(math.min(req_limbs, a.limbs.len + b.limbs.len)); @@ -691,7 +709,7 @@ pub const Mutable = struct { mem.set(Limb, rma.limbs, 0); - llsquare_basecase(rma.limbs, a.limbs); + llsquareBasecase(rma.limbs, a.limbs); rma.normalize(2 * a.limbs.len + 1); rma.positive = true; @@ -1219,7 +1237,7 @@ pub const Mutable = struct { /// Asserts `r` has enough storage to store the result. /// The upper bound is `calcTwosCompLimbCount(a.len)`. pub fn truncate(r: *Mutable, a: Const, signedness: std.builtin.Signedness, bit_count: usize) void { - const req_limbs = (bit_count + @bitSizeOf(Limb) - 1) / @bitSizeOf(Limb); + const req_limbs = calcTwosCompLimbCount(bit_count); // Handle 0-bit integers. if (req_limbs == 0 or a.eqZero()) { @@ -2319,8 +2337,8 @@ pub const Managed = struct { } }; -/// r = a * b, ignoring overflow -fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void { +/// r = r + a * b, ignoring overflow +fn llmulaccLow(r: []Limb, a: []const Limb, b: []const Limb) void { assert(r.len >= a.len); assert(a.len >= b.len); @@ -2328,32 +2346,41 @@ fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void { var i: usize = 0; while (i < b.len) : (i += 1) { - llmulDigit(r[i..], a, b[i]); + llmulLimb(.add, r[i..], a, b[i]); } } +/// Different operators which can be used in accumulation style functions +/// (llmulacc, llmulaccKaratsuba, llmulaccLong, llmulLimb). In all these functions, +/// a computed value is accumulated with an existing result. +const AccOp = enum { + /// The computed value is added to the result. + add, + + /// The computed value is subtracted from the result. + sub, +}; + /// Knuth 4.3.1, Algorithm M. /// +/// r = r (op) a * b /// r MUST NOT alias any of a or b. -fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void { +fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void { @setRuntimeSafety(debug_safety); + assert(r.len >= a.len + b.len); - const a_norm = a[0..llnormalize(a)]; - const b_norm = b[0..llnormalize(b)]; - var x = a_norm; - var y = b_norm; - if (a_norm.len > b_norm.len) { - x = b_norm; - y = a_norm; + // Order greatest first. + var x = a; + var y = b; + if (a.len < b.len) { + x = b; + y = a; } - assert(r.len >= x.len + y.len + 1); - - // 48 is a pretty abitrary size chosen based on performance of a factorial program. k_mul: { - if (x.len > 48) { + if (y.len > 48) { if (opt_allocator) |allocator| { - llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) { + llmulaccKaratsuba(op, allocator, r, x, y) catch |err| switch (err) { error.OutOfMemory => break :k_mul, // handled below }; return; @@ -2361,83 +2388,153 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L } } - // Basecase multiplication - var i: usize = 0; - while (i < x.len) : (i += 1) { - llmulDigit(r[i..], y, x[i]); - } + llmulaccLong(op, r, x, y); } /// Knuth 4.3.1, Algorithm M. /// +/// r = r (op) a * b /// r MUST NOT alias any of a or b. -fn llmulacc_karatsuba(allocator: *Allocator, r: []Limb, x: []const Limb, y: []const Limb) error{OutOfMemory}!void { +fn llmulaccKaratsuba( + comptime op: AccOp, + allocator: *Allocator, + r: []Limb, + a: []const Limb, + b: []const Limb, +) error{OutOfMemory}!void { @setRuntimeSafety(debug_safety); + assert(r.len >= a.len + b.len); + assert(a.len >= b.len); - assert(r.len >= x.len + y.len + 1); + // Classical karatsuba algorithm: + // a = a1 * B + a0 + // b = b1 * B + b0 + // Where a0, b0 < B + // + // We then have: + // ab = a * b + // = (a1 * B + a0) * (b1 * B + b0) + // = a1 * b1 * B * B + a1 * B * b0 + a0 * b1 * B + a0 * b0 + // = a1 * b1 * B * B + (a1 * b0 + a0 * b1) * B + a0 * b0 + // + // Note that: + // a1 * b0 + a0 * b1 + // = (a1 + a0)(b1 + b0) - a1 * b1 - a0 * b0 + // = (a0 - a1)(b1 - b0) + a1 * b1 + a0 * b0 + // + // This yields: + // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0 + // + // Where: + // p0 = a0 * b0 + // p1 = (a0 - a1)(b1 - b0) + // p2 = a1 * b1 + // + // Note, (a0 - a1) and (b1 - b0) produce values -B < x < B, and so we need to mind the sign here. + // We also have: + // 0 <= p0 <= 2B + // -2B <= p1 <= 2B + // + // Note, when B is a multiple of the limb size, multiplies by B amount to shifts or + // slices of a limbs array. - const split = @divFloor(x.len, 2); - var x0 = x[0..split]; - var x1 = x[split..x.len]; - var y0 = y[0..split]; - var y1 = y[split..y.len]; + const split = b.len / 2; // B + const a0 = a[0..llnormalize(a[0..split])]; + const a1 = a[split..][0..llnormalize(a[split..])]; + const b0 = b[0..llnormalize(b[0..split])]; + const b1 = b[split..][0..llnormalize(b[split..])]; - var tmp = try allocator.alloc(Limb, x1.len + y1.len + 1); + // Note that the above slices work because we have a.len > b.len. + // We now also have: + // a1.len >= a0.len + // a1.len >= b1.len >= b0.len + // a0.len == b0.len + + // We need some temporary memory to store intermediate results. + // Note, we can reduce the amount of temporaries we need by reordering the computation here: + // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0 + // = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0 + // = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B + // By allocating a1.len * b1.len we can be sure that all the intermediary results fit. + const tmp = try allocator.alloc(Limb, a.len - split + b.len - split); defer allocator.free(tmp); + + // Compute p2. mem.set(Limb, tmp, 0); + llmulacc(.add, allocator, tmp, a1, b1); + const p2 = tmp[0 .. llnormalize(tmp)]; - llmulacc(allocator, tmp, x1, y1); + // Add terms p2 * B^2 and p2 * B to the result. + _ = llaccum(op, r[split..], p2); + _ = llaccum(op, r[split * 2..], p2); - var length = llnormalize(tmp); - _ = llaccum(r[split..], tmp[0..length]); - _ = llaccum(r[split * 2 ..], tmp[0..length]); + // Compute p0. + mem.set(Limb, p2, 0); + llmulacc(.add, allocator, tmp, a0, b0); + const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])]; - mem.set(Limb, tmp[0..length], 0); + // Add terms p0 * B and p0 to the result. + _ = llaccum(op, r, p0); + _ = llaccum(op, r[split..], p0); - llmulacc(allocator, tmp, x0, y0); + // Finally, compute and add p1. + const j0_sign = llcmp(a0, a1); + const j1_sign = llcmp(b1, b0); - length = llnormalize(tmp); - _ = llaccum(r[0..], tmp[0..length]); - _ = llaccum(r[split..], tmp[0..length]); - - const x_cmp = llcmp(x1, x0); - const y_cmp = llcmp(y1, y0); - if (x_cmp * y_cmp == 0) { + if (j0_sign * j1_sign == 0) { + // p1 is zero, we don't need to do any computation at all. return; } - const x0_len = llnormalize(x0); - const x1_len = llnormalize(x1); - var j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len)); - defer allocator.free(j0); - if (x_cmp == 1) { - llsub(j0, x1[0..x1_len], x0[0..x0_len]); + + mem.set(Limb, tmp, 0); + + // p1 is nonzero, so compute the intermediary terms j0 = a0 - a1 and j1 = b1 - b0. + // Note that in this case, we again need some storage for intermediary results + // j0 and j1. Since we have tmp.len >= 2B, we can store both + // intermediaries in the already allocated array. + const j0 = tmp[0..a1.len]; + const j1 = tmp[a1.len..]; + + // Ensure that no subtraction overflows. + if (j0_sign == 1) { + // a0 > a1. + _ = llsubcarry(j0, a0, a1); } else { - llsub(j0, x0[0..x0_len], x1[0..x1_len]); + // a0 < a1. + _ = llsubcarry(j0, a1, a0); } - const y0_len = llnormalize(y0); - const y1_len = llnormalize(y1); - var j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len)); - defer allocator.free(j1); - if (y_cmp == 1) { - llsub(j1, y1[0..y1_len], y0[0..y0_len]); + if (j1_sign == 1) { + // b1 > b0. + _ = llsubcarry(j1, b1, b0); } else { - llsub(j1, y0[0..y0_len], y1[0..y1_len]); + // b1 > b0. + _ = llsubcarry(j1, b0, b1); } - if (x_cmp == y_cmp) { - mem.set(Limb, tmp[0..length], 0); - llmulacc(allocator, tmp, j0, j1); - length = llnormalize(tmp); - llsub(r[split..], r[split..], tmp[0..length]); + if (j0_sign * j1_sign == 1) { + // If j0 and j1 are both positive, we now have: + // p1 = j0 * j1 + // If j0 and j1 are both negative, we now have: + // p1 = -j0 * -j1 = j0 * j1 + // In this case we can add p1 to the result using llmulacc. + llmulacc(op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]); } else { - llmulacc(allocator, r[split..], j0, j1); + // In this case either j0 or j1 is negative, an we have: + // p1 = -(j0 * j1) + // Now we need to subtract instead of accumulate. + const inverted_op = if (op == .add) .sub else .add; + llmulacc(inverted_op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]); } } -// r = r + a -fn llaccum(r: []Limb, a: []const Limb) Limb { +// r = r (op) a +fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb { @setRuntimeSafety(debug_safety); + if (op == .sub) { + return llsubcarry(r, r, a); + } + assert(r.len != 0 and a.len != 0); assert(r.len >= a.len); @@ -2486,24 +2583,53 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 { } } -fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void { +// r = r (op) y * xi +fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void { + @setRuntimeSafety(debug_safety); + assert(r.len >= a.len + b.len); + assert(a.len >= b.len); + + var i: usize = 0; + while (i < a.len) : (i += 1) { + llmulLimb(op, r[i..], b, a[i]); + } +} + +// r = r (op) y * xi +fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void { @setRuntimeSafety(debug_safety); if (xi == 0) { return; } - var carry: Limb = 0; var a_lo = acc[0..y.len]; var a_hi = acc[y.len..]; - var j: usize = 0; - while (j < a_lo.len) : (j += 1) { - a_lo[j] = @call(.{ .modifier = .always_inline }, addMulLimbWithCarry, .{ a_lo[j], y[j], xi, &carry }); - } + switch (op) { + .add => { + var carry: Limb = 0; + var j: usize = 0; + while (j < a_lo.len) : (j += 1) { + a_lo[j] = addMulLimbWithCarry(a_lo[j], y[j], xi, &carry); + } - j = 0; - while ((carry != 0) and (j < a_hi.len)) : (j += 1) { - carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j])); + j = 0; + while ((carry != 0) and (j < a_hi.len)) : (j += 1) { + carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j])); + } + }, + .sub => { + var borrow: Limb = 0; + var j: usize = 0; + while (j < a_lo.len) : (j += 1) { + a_lo[j] = subMulLimbWithBorrow(a_lo[j], y[j], xi, &borrow); + } + + j = 0; + while ((borrow != 0) and (j < a_hi.len)) : (j += 1) { + borrow = @boolToInt(@subWithOverflow(Limb, a_hi[j], borrow, &a_hi[j])); + } + }, } } @@ -2964,7 +3090,7 @@ fn llsignedxor(r: []Limb, a: []const Limb, a_positive: bool, b: []const Limb, b_ } /// r MUST NOT alias x. -fn llsquare_basecase(r: []Limb, x: []const Limb) void { +fn llsquareBasecase(r: []Limb, x: []const Limb) void { @setRuntimeSafety(debug_safety); const x_norm = x; @@ -2987,7 +3113,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void { for (x_norm) |v, i| { // Accumulate all the x[i]*x[j] (with x!=j) products - llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v); + llmulLimb(.add, r[2 * i + 1 ..], x_norm[i + 1 ..], v); } // Each product appears twice, multiply by 2 @@ -2995,7 +3121,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void { for (x_norm) |v, i| { // Compute and add the squares - llmulDigit(r[2 * i ..], x[i .. i + 1], v); + llmulLimb(.add, r[2 * i ..], x[i .. i + 1], v); } } @@ -3034,12 +3160,12 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void { while (i < exp_bits) : (i += 1) { // Square mem.set(Limb, tmp2, 0); - llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]); + llsquareBasecase(tmp2, tmp1[0..llnormalize(tmp1)]); mem.swap([]Limb, &tmp1, &tmp2); // Multiply by a if (@shlWithOverflow(u32, exp, 1, &exp)) { mem.set(Limb, tmp2, 0); - llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a); + llmulacc(.add, null, tmp2, tmp1[0..llnormalize(tmp1)], a); mem.swap([]Limb, &tmp1, &tmp2); } }