From 52207f22de29e69c1084b321277b25e5a6aefb04 Mon Sep 17 00:00:00 2001 From: Brendan Hansknecht Date: Sat, 2 Nov 2019 21:37:53 -0700 Subject: [PATCH] Add karatsuba to big ints --- lib/std/math/big/int.zig | 184 +++++++++++++++++++++++++++++++++++---- 1 file changed, 168 insertions(+), 16 deletions(-) diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index 8a6f6c1f75..bfdc768375 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -766,13 +766,11 @@ pub const Int = struct { r.deinit(); }; - try r.ensureCapacity(a.len() + b.len()); + try r.ensureCapacity(a.len() + b.len() + 1); - if (a.len() >= b.len()) { - llmul(r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]); - } else { - llmul(r.limbs, b.limbs[0..b.len()], a.limbs[0..a.len()]); - } + mem.set(Limb, r.limbs[0 .. a.len() + b.len() + 1], 0); + + try llmulacc(rma.allocator.?, r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]); r.normalize(a.len() + b.len()); r.setSign(a.isPositive() == b.isPositive()); @@ -780,6 +778,7 @@ pub const Int = struct { // a + b * c + *carry, sets carry to the overflow bits pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb { + @setRuntimeSafety(false); var r1: Limb = undefined; // r1 = a + *carry @@ -800,25 +799,178 @@ pub const Int = struct { return r1; } + fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void { + @setRuntimeSafety(false); + if (xi == 0) { + return; + } + + var carry: usize = 0; + var a_lo = acc[0..y.len]; + var a_hi = acc[y.len..]; + + var j: usize = 0; + while (j < a_lo.len) : (j += 1) { + a_lo[j] = @inlineCall(addMulLimbWithCarry, a_lo[j], y[j], xi, &carry); + } + + j = 0; + while ((carry != 0) and (j < a_hi.len)) : (j += 1) { + carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j])); + } + } + // Knuth 4.3.1, Algorithm M. // // r MUST NOT alias any of a or b. 
- fn llmul(r: []Limb, a: []const Limb, b: []const Limb) void { + fn llmulacc(allocator: *Allocator, r: []Limb, a: []const Limb, b: []const Limb) error{OutOfMemory}!void { @setRuntimeSafety(false); - debug.assert(a.len >= b.len); - debug.assert(r.len >= a.len + b.len); - mem.set(Limb, r[0 .. a.len + b.len], 0); + const a_norm = a[0..llnormalize(a)]; + const b_norm = b[0..llnormalize(b)]; + var x = a_norm; + var y = b_norm; + if (a_norm.len > b_norm.len) { + x = b_norm; + y = a_norm; + } + + debug.assert(r.len >= x.len + y.len + 1); + + // 48 is a pretty arbitrary size chosen based on performance of a factorial program. + if (x.len <= 48) { + // Basecase multiplication + var i: usize = 0; + while (i < x.len) : (i += 1) { + llmulDigit(r[i..], y, x[i]); + } + } else { + // Karatsuba multiplication + const split = @divFloor(x.len, 2); + var x0 = x[0..split]; + var x1 = x[split..x.len]; + var y0 = y[0..split]; + var y1 = y[split..y.len]; + + var tmp = try allocator.alloc(Limb, x1.len + y1.len + 1); + defer allocator.free(tmp); + mem.set(Limb, tmp, 0); + + try llmulacc(allocator, tmp, x1, y1); + + var length = llnormalize(tmp); + _ = llaccum(r[split..], tmp[0..length]); + _ = llaccum(r[split * 2 ..], tmp[0..length]); + + mem.set(Limb, tmp[0..length], 0); + + try llmulacc(allocator, tmp, x0, y0); + + length = llnormalize(tmp); + _ = llaccum(r[0..], tmp[0..length]); + _ = llaccum(r[split..], tmp[0..length]); + + const x_cmp = llcmp(x1, x0); + const y_cmp = llcmp(y1, y0); + if (x_cmp * y_cmp == 0) { + return; + } + const x0_len = llnormalize(x0); + const x1_len = llnormalize(x1); + var j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len)); + defer allocator.free(j0); + if (x_cmp == 1) { + llsub(j0, x1[0..x1_len], x0[0..x0_len]); + } else { + llsub(j0, x0[0..x0_len], x1[0..x1_len]); + } + + const y0_len = llnormalize(y0); + const y1_len = llnormalize(y1); + var j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len)); + defer allocator.free(j1); + if (y_cmp == 1) { + 
llsub(j1, y1[0..y1_len], y0[0..y0_len]); + } else { + llsub(j1, y0[0..y0_len], y1[0..y1_len]); + } + const j0_len = llnormalize(j0); + const j1_len = llnormalize(j1); + if (x_cmp == y_cmp) { + mem.set(Limb, tmp[0..length], 0); + try llmulacc(allocator, tmp, j0, j1); + + length = Int.llnormalize(tmp); + llsub(r[split..], r[split..], tmp[0..length]); + } else { + try llmulacc(allocator, r[split..], j0, j1); + } + } + } + + // r = r + a + fn llaccum(r: []Limb, a: []const Limb) Limb { + @setRuntimeSafety(false); + debug.assert(r.len != 0 and a.len != 0); + debug.assert(r.len >= a.len); var i: usize = 0; + var carry: Limb = 0; + while (i < a.len) : (i += 1) { - var carry: Limb = 0; - var j: usize = 0; - while (j < b.len) : (j += 1) { - r[i + j] = @inlineCall(addMulLimbWithCarry, r[i + j], a[i], b[j], &carry); - } - r[i + j] = carry; + var c: Limb = 0; + c += @boolToInt(@addWithOverflow(Limb, r[i], a[i], &r[i])); + c += @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i])); + carry = c; } + + while ((carry != 0) and i < r.len) : (i += 1) { + carry = @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i])); + } + + return carry; + } + + /// Returns -1, 0, 1 if |a| < |b|, |a| == |b| or |a| > |b| respectively for limbs. + pub fn llcmp(a: []const Limb, b: []const Limb) i8 { + @setRuntimeSafety(false); + const a_len = llnormalize(a); + const b_len = llnormalize(b); + if (a_len < b_len) { + return -1; + } + if (a_len > b_len) { + return 1; + } + + var i: usize = a_len - 1; + while (i != 0) : (i -= 1) { + if (a[i] != b[i]) { + break; + } + } + + if (a[i] < b[i]) { + return -1; + } else if (a[i] > b[i]) { + return 1; + } else { + return 0; + } + } + + // Returns the length of a with its high-order zero limbs stripped; never less than 1. + fn llnormalize(a: []const Limb) usize { + @setRuntimeSafety(false); + var j = a.len; + while (j > 0) : (j -= 1) { + if (a[j - 1] != 0) { + break; + } + } + + // Handle zero + return if (j != 0) j else 1; } /// q = a / b (rem r)