From 30d392a87e7ebd51269a557b6bccbc3163e8db75 Mon Sep 17 00:00:00 2001 From: Frank Denis <124872+jedisct1@users.noreply.github.com> Date: Sun, 6 Nov 2022 23:52:41 +0100 Subject: [PATCH] crypto.salsa20: make the number of rounds a comptime parameter (#13442) ...instead of hard-coding it to 20. - This is consistent with the ChaCha implementation - NaCl and libsodium, that this API is designed to interop with, also support 8 and 12 round variants. The 12 round variant, in particular, provides the same security level as the 20 round variant, but is obviously faster. - scrypt currently uses its own non optimized version of Salsa, just because it use 8 rounds instead of 20. This will help remove code duplication. No behavior nor public API changes. The Salsa20 and XSalsa20 still represent the 20-round variant. --- lib/std/crypto.zig | 2 + lib/std/crypto/salsa20.zig | 600 +++++++++++++++++++------------------ 2 files changed, 303 insertions(+), 299 deletions(-) diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 46dfa6a715..f9fa50c692 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -147,6 +147,8 @@ pub const stream = struct { }; pub const salsa = struct { + pub const Salsa = @import("crypto/salsa20.zig").Salsa; + pub const XSalsa = @import("crypto/salsa20.zig").XSalsa; pub const Salsa20 = @import("crypto/salsa20.zig").Salsa20; pub const XSalsa20 = @import("crypto/salsa20.zig").XSalsa20; }; diff --git a/lib/std/crypto/salsa20.zig b/lib/std/crypto/salsa20.zig index 7477b7ad69..c4cd86b0e4 100644 --- a/lib/std/crypto/salsa20.zig +++ b/lib/std/crypto/salsa20.zig @@ -14,297 +14,293 @@ const AuthenticationError = crypto.errors.AuthenticationError; const IdentityElementError = crypto.errors.IdentityElementError; const WeakPublicKeyError = crypto.errors.WeakPublicKeyError; -const Salsa20VecImpl = struct { - const Lane = @Vector(4, u32); - const Half = @Vector(2, u32); - const BlockVec = [4]Lane; +/// The Salsa cipher with 20 rounds. +pub const Salsa20 = Salsa(20); - fn initContext(key: [8]u32, d: [4]u32) BlockVec { - const c = "expand 32-byte k"; - const constant_le = comptime [4]u32{ - mem.readIntLittle(u32, c[0..4]), - mem.readIntLittle(u32, c[4..8]), - mem.readIntLittle(u32, c[8..12]), - mem.readIntLittle(u32, c[12..16]), - }; - return BlockVec{ - Lane{ key[0], key[1], key[2], key[3] }, - Lane{ key[4], key[5], key[6], key[7] }, - Lane{ constant_le[0], constant_le[1], constant_le[2], constant_le[3] }, - Lane{ d[0], d[1], d[2], d[3] }, - }; - } +/// The XSalsa cipher with 20 rounds. +pub const XSalsa20 = XSalsa(20); - inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void { - const n1n2n3n0 = Lane{ input[3][1], input[3][2], input[3][3], input[3][0] }; - const n1n2 = Half{ n1n2n3n0[0], n1n2n3n0[1] }; - const n3n0 = Half{ n1n2n3n0[2], n1n2n3n0[3] }; - const k0k1 = Half{ input[0][0], input[0][1] }; - const k2k3 = Half{ input[0][2], input[0][3] }; - const k4k5 = Half{ input[1][0], input[1][1] }; - const k6k7 = Half{ input[1][2], input[1][3] }; - const n0k0 = Half{ n3n0[1], k0k1[0] }; - const k0n0 = Half{ n0k0[1], n0k0[0] }; - const k4k5k0n0 = Lane{ k4k5[0], k4k5[1], k0n0[0], k0n0[1] }; - const k1k6 = Half{ k0k1[1], k6k7[0] }; - const k6k1 = Half{ k1k6[1], k1k6[0] }; - const n1n2k6k1 = Lane{ n1n2[0], n1n2[1], k6k1[0], k6k1[1] }; - const k7n3 = Half{ k6k7[1], n3n0[0] }; - const n3k7 = Half{ k7n3[1], k7n3[0] }; - const k2k3n3k7 = Lane{ k2k3[0], k2k3[1], n3k7[0], n3k7[1] }; +fn SalsaVecImpl(comptime rounds: comptime_int) type { + return struct { + const Lane = @Vector(4, u32); + const Half = @Vector(2, u32); + const BlockVec = [4]Lane; - var diag0 = input[2]; - var diag1 = @shuffle(u32, k4k5k0n0, undefined, [_]i32{ 1, 2, 3, 0 }); - var diag2 = @shuffle(u32, n1n2k6k1, undefined, [_]i32{ 1, 2, 3, 0 }); - var diag3 = @shuffle(u32, k2k3n3k7, undefined, [_]i32{ 1, 2, 3, 0 }); - - const start0 = diag0; - const start1 = diag1; - const start2 = diag2; - const start3 = diag3; - - var i: usize = 0; - while (i < 20) : (i += 2) { - var a0 = diag1 +% diag0; - diag3 ^= math.rotl(Lane, a0, 7); - var a1 = diag0 +% diag3; - diag2 ^= math.rotl(Lane, a1, 9); - var a2 = diag3 +% diag2; - diag1 ^= math.rotl(Lane, a2, 13); - var a3 = diag2 +% diag1; - diag0 ^= math.rotl(Lane, a3, 18); - - var diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 3, 0, 1, 2 }); - var diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 }); - var diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 1, 2, 3, 0 }); - diag3 = diag3_shift; - diag2 = diag2_shift; - diag1 = diag1_shift; - - a0 = diag3 +% diag0; - diag1 ^= math.rotl(Lane, a0, 7); - a1 = diag0 +% diag1; - diag2 ^= math.rotl(Lane, a1, 9); - a2 = diag1 +% diag2; - diag3 ^= math.rotl(Lane, a2, 13); - a3 = diag2 +% diag3; - diag0 ^= math.rotl(Lane, a3, 18); - - diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 3, 0, 1, 2 }); - diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 }); - diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 1, 2, 3, 0 }); - diag1 = diag1_shift; - diag2 = diag2_shift; - diag3 = diag3_shift; + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime [4]u32{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + Lane{ key[0], key[1], key[2], key[3] }, + Lane{ key[4], key[5], key[6], key[7] }, + Lane{ constant_le[0], constant_le[1], constant_le[2], constant_le[3] }, + Lane{ d[0], d[1], d[2], d[3] }, + }; } - if (feedback) { - diag0 +%= start0; - diag1 +%= start1; - diag2 +%= start2; - diag3 +%= start3; - } + inline fn salsaCore(x: *BlockVec, input: BlockVec, comptime feedback: bool) void { + const n1n2n3n0 = Lane{ input[3][1], input[3][2], input[3][3], input[3][0] }; + const n1n2 = Half{ n1n2n3n0[0], n1n2n3n0[1] }; + const n3n0 = Half{ n1n2n3n0[2], n1n2n3n0[3] }; + const k0k1 = Half{ input[0][0], input[0][1] }; + const k2k3 = Half{ input[0][2], input[0][3] }; + const k4k5 = Half{ input[1][0], input[1][1] }; + const k6k7 = Half{ input[1][2], input[1][3] }; + const n0k0 = Half{ n3n0[1], k0k1[0] }; + const k0n0 = Half{ n0k0[1], n0k0[0] }; + const k4k5k0n0 = Lane{ k4k5[0], k4k5[1], k0n0[0], k0n0[1] }; + const k1k6 = Half{ k0k1[1], k6k7[0] }; + const k6k1 = Half{ k1k6[1], k1k6[0] }; + const n1n2k6k1 = Lane{ n1n2[0], n1n2[1], k6k1[0], k6k1[1] }; + const k7n3 = Half{ k6k7[1], n3n0[0] }; + const n3k7 = Half{ k7n3[1], k7n3[0] }; + const k2k3n3k7 = Lane{ k2k3[0], k2k3[1], n3k7[0], n3k7[1] }; - const x0x1x10x11 = Lane{ diag0[0], diag1[1], diag0[2], diag1[3] }; - const x12x13x6x7 = Lane{ diag1[0], diag2[1], diag1[2], diag2[3] }; - const x8x9x2x3 = Lane{ diag2[0], diag3[1], diag2[2], diag3[3] }; - const x4x5x14x15 = Lane{ diag3[0], diag0[1], diag3[2], diag0[3] }; + var diag0 = input[2]; + var diag1 = @shuffle(u32, k4k5k0n0, undefined, [_]i32{ 1, 2, 3, 0 }); + var diag2 = @shuffle(u32, n1n2k6k1, undefined, [_]i32{ 1, 2, 3, 0 }); + var diag3 = @shuffle(u32, k2k3n3k7, undefined, [_]i32{ 1, 2, 3, 0 }); - x[0] = Lane{ x0x1x10x11[0], x0x1x10x11[1], x8x9x2x3[2], x8x9x2x3[3] }; - x[1] = Lane{ x4x5x14x15[0], x4x5x14x15[1], x12x13x6x7[2], x12x13x6x7[3] }; - x[2] = Lane{ x8x9x2x3[0], x8x9x2x3[1], x0x1x10x11[2], x0x1x10x11[3] }; - x[3] = Lane{ x12x13x6x7[0], x12x13x6x7[1], x4x5x14x15[2], x4x5x14x15[3] }; - } + const start0 = diag0; + const start1 = diag1; + const start2 = diag2; + const start3 = diag3; - fn hashToBytes(out: *[64]u8, x: BlockVec) void { - var i: usize = 0; - while (i < 4) : (i += 1) { - mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]); - mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]); - mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]); - mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]); - } - } + var i: usize = 0; + while (i < rounds) : (i += 2) { + diag3 ^= math.rotl(Lane, diag1 +% diag0, 7); + diag2 ^= math.rotl(Lane, diag0 +% diag3, 9); + diag1 ^= math.rotl(Lane, diag3 +% diag2, 13); + diag0 ^= math.rotl(Lane, diag2 +% diag1, 18); - fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void { - var ctx = initContext(key, d); - var x: BlockVec = undefined; - var buf: [64]u8 = undefined; - var i: usize = 0; - while (i + 64 <= in.len) : (i += 64) { - salsa20Core(x[0..], ctx, true); - hashToBytes(buf[0..], x); - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < 64) : (j += 1) { - xout[j] = xin[j]; + diag3 = @shuffle(u32, diag3, undefined, [_]i32{ 3, 0, 1, 2 }); + diag2 = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 }); + diag1 = @shuffle(u32, diag1, undefined, [_]i32{ 1, 2, 3, 0 }); + + diag1 ^= math.rotl(Lane, diag3 +% diag0, 7); + diag2 ^= math.rotl(Lane, diag0 +% diag1, 9); + diag3 ^= math.rotl(Lane, diag1 +% diag2, 13); + diag0 ^= math.rotl(Lane, diag2 +% diag3, 18); + + diag1 = @shuffle(u32, diag1, undefined, [_]i32{ 3, 0, 1, 2 }); + diag2 = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 }); + diag3 = @shuffle(u32, diag3, undefined, [_]i32{ 1, 2, 3, 0 }); } - j = 0; - while (j < 64) : (j += 1) { - xout[j] ^= buf[j]; + + if (feedback) { + diag0 +%= start0; + diag1 +%= start1; + diag2 +%= start2; + diag3 +%= start3; } - ctx[3][2] +%= 1; - if (ctx[3][2] == 0) { - ctx[3][3] += 1; + + const x0x1x10x11 = Lane{ diag0[0], diag1[1], diag0[2], diag1[3] }; + const x12x13x6x7 = Lane{ diag1[0], diag2[1], diag1[2], diag2[3] }; + const x8x9x2x3 = Lane{ diag2[0], diag3[1], diag2[2], diag3[3] }; + const x4x5x14x15 = Lane{ diag3[0], diag0[1], diag3[2], diag0[3] }; + + x[0] = Lane{ x0x1x10x11[0], x0x1x10x11[1], x8x9x2x3[2], x8x9x2x3[3] }; + x[1] = Lane{ x4x5x14x15[0], x4x5x14x15[1], x12x13x6x7[2], x12x13x6x7[3] }; + x[2] = Lane{ x8x9x2x3[0], x8x9x2x3[1], x0x1x10x11[2], x0x1x10x11[3] }; + x[3] = Lane{ x12x13x6x7[0], x12x13x6x7[1], x4x5x14x15[2], x4x5x14x15[3] }; + } + + fn hashToBytes(out: *[64]u8, x: BlockVec) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]); } } - if (i < in.len) { - salsa20Core(x[0..], ctx, true); - hashToBytes(buf[0..], x); - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < in.len % 64) : (j += 1) { - xout[j] = xin[j] ^ buf[j]; + fn salsaXor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void { + var ctx = initContext(key, d); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + salsaCore(x[0..], ctx, true); + hashToBytes(buf[0..], x); + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[3][2] +%= 1; + if (ctx[3][2] == 0) { + ctx[3][3] += 1; + } + } + if (i < in.len) { + salsaCore(x[0..], ctx, true); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } } } - } - fn hsalsa20(input: [16]u8, key: [32]u8) [32]u8 { - var c: [4]u32 = undefined; - for (c) |_, i| { - c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + fn hsalsa(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + salsaCore(x[0..], ctx, false); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0][0]); + mem.writeIntLittle(u32, out[4..8], x[1][1]); + mem.writeIntLittle(u32, out[8..12], x[2][2]); + mem.writeIntLittle(u32, out[12..16], x[3][3]); + mem.writeIntLittle(u32, out[16..20], x[1][2]); + mem.writeIntLittle(u32, out[20..24], x[1][3]); + mem.writeIntLittle(u32, out[24..28], x[2][0]); + mem.writeIntLittle(u32, out[28..32], x[2][1]); + return out; } - const ctx = initContext(keyToWords(key), c); - var x: BlockVec = undefined; - salsa20Core(x[0..], ctx, false); - var out: [32]u8 = undefined; - mem.writeIntLittle(u32, out[0..4], x[0][0]); - mem.writeIntLittle(u32, out[4..8], x[1][1]); - mem.writeIntLittle(u32, out[8..12], x[2][2]); - mem.writeIntLittle(u32, out[12..16], x[3][3]); - mem.writeIntLittle(u32, out[16..20], x[1][2]); - mem.writeIntLittle(u32, out[20..24], x[1][3]); - mem.writeIntLittle(u32, out[24..28], x[2][0]); - mem.writeIntLittle(u32, out[28..32], x[2][1]); - return out; - } -}; - -const Salsa20NonVecImpl = struct { - const BlockVec = [16]u32; - - fn initContext(key: [8]u32, d: [4]u32) BlockVec { - const c = "expand 32-byte k"; - const constant_le = comptime [4]u32{ - mem.readIntLittle(u32, c[0..4]), - mem.readIntLittle(u32, c[4..8]), - mem.readIntLittle(u32, c[8..12]), - mem.readIntLittle(u32, c[12..16]), - }; - return BlockVec{ - constant_le[0], key[0], key[1], key[2], - key[3], constant_le[1], d[0], d[1], - d[2], d[3], constant_le[2], key[4], - key[5], key[6], key[7], constant_le[3], - }; - } - - const QuarterRound = struct { - a: usize, - b: usize, - c: usize, - d: u6, }; +} - inline fn Rp(a: usize, b: usize, c: usize, d: u6) QuarterRound { - return QuarterRound{ - .a = a, - .b = b, - .c = c, - .d = d, +fn SalsaNonVecImpl(comptime rounds: comptime_int) type { + return struct { + const BlockVec = [16]u32; + + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime [4]u32{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le[0], key[0], key[1], key[2], + key[3], constant_le[1], d[0], d[1], + d[2], d[3], constant_le[2], key[4], + key[5], key[6], key[7], constant_le[3], + }; + } + + const QuarterRound = struct { + a: usize, + b: usize, + c: usize, + d: u6, }; - } - inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void { - const arx_steps = comptime [_]QuarterRound{ - Rp(4, 0, 12, 7), Rp(8, 4, 0, 9), Rp(12, 8, 4, 13), Rp(0, 12, 8, 18), - Rp(9, 5, 1, 7), Rp(13, 9, 5, 9), Rp(1, 13, 9, 13), Rp(5, 1, 13, 18), - Rp(14, 10, 6, 7), Rp(2, 14, 10, 9), Rp(6, 2, 14, 13), Rp(10, 6, 2, 18), - Rp(3, 15, 11, 7), Rp(7, 3, 15, 9), Rp(11, 7, 3, 13), Rp(15, 11, 7, 18), - Rp(1, 0, 3, 7), Rp(2, 1, 0, 9), Rp(3, 2, 1, 13), Rp(0, 3, 2, 18), - Rp(6, 5, 4, 7), Rp(7, 6, 5, 9), Rp(4, 7, 6, 13), Rp(5, 4, 7, 18), - Rp(11, 10, 9, 7), Rp(8, 11, 10, 9), Rp(9, 8, 11, 13), Rp(10, 9, 8, 18), - Rp(12, 15, 14, 7), Rp(13, 12, 15, 9), Rp(14, 13, 12, 13), Rp(15, 14, 13, 18), - }; - x.* = input; - var j: usize = 0; - while (j < 20) : (j += 2) { - inline for (arx_steps) |r| { - x[r.a] ^= math.rotl(u32, x[r.b] +% x[r.c], r.d); - } + inline fn Rp(a: usize, b: usize, c: usize, d: u6) QuarterRound { + return QuarterRound{ + .a = a, + .b = b, + .c = c, + .d = d, + }; } - if (feedback) { - j = 0; - while (j < 16) : (j += 1) { - x[j] +%= input[j]; - } - } - } - fn hashToBytes(out: *[64]u8, x: BlockVec) void { - for (x) |w, i| { - mem.writeIntLittle(u32, out[i * 4 ..][0..4], w); - } - } - - fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void { - var ctx = initContext(key, d); - var x: BlockVec = undefined; - var buf: [64]u8 = undefined; - var i: usize = 0; - while (i + 64 <= in.len) : (i += 64) { - salsa20Core(x[0..], ctx, true); - hashToBytes(buf[0..], x); - var xout = out[i..]; - const xin = in[i..]; + inline fn salsaCore(x: *BlockVec, input: BlockVec, comptime feedback: bool) void { + const arx_steps = comptime [_]QuarterRound{ + Rp(4, 0, 12, 7), Rp(8, 4, 0, 9), Rp(12, 8, 4, 13), Rp(0, 12, 8, 18), + Rp(9, 5, 1, 7), Rp(13, 9, 5, 9), Rp(1, 13, 9, 13), Rp(5, 1, 13, 18), + Rp(14, 10, 6, 7), Rp(2, 14, 10, 9), Rp(6, 2, 14, 13), Rp(10, 6, 2, 18), + Rp(3, 15, 11, 7), Rp(7, 3, 15, 9), Rp(11, 7, 3, 13), Rp(15, 11, 7, 18), + Rp(1, 0, 3, 7), Rp(2, 1, 0, 9), Rp(3, 2, 1, 13), Rp(0, 3, 2, 18), + Rp(6, 5, 4, 7), Rp(7, 6, 5, 9), Rp(4, 7, 6, 13), Rp(5, 4, 7, 18), + Rp(11, 10, 9, 7), Rp(8, 11, 10, 9), Rp(9, 8, 11, 13), Rp(10, 9, 8, 18), + Rp(12, 15, 14, 7), Rp(13, 12, 15, 9), Rp(14, 13, 12, 13), Rp(15, 14, 13, 18), + }; + x.* = input; var j: usize = 0; - while (j < 64) : (j += 1) { - xout[j] = xin[j]; + while (j < rounds) : (j += 2) { + inline for (arx_steps) |r| { + x[r.a] ^= math.rotl(u32, x[r.b] +% x[r.c], r.d); + } } - j = 0; - while (j < 64) : (j += 1) { - xout[j] ^= buf[j]; - } - ctx[9] += @boolToInt(@addWithOverflow(u32, ctx[8], 1, &ctx[8])); - } - if (i < in.len) { - salsa20Core(x[0..], ctx, true); - hashToBytes(buf[0..], x); - - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < in.len % 64) : (j += 1) { - xout[j] = xin[j] ^ buf[j]; + if (feedback) { + j = 0; + while (j < 16) : (j += 1) { + x[j] +%= input[j]; + } } } - } - fn hsalsa20(input: [16]u8, key: [32]u8) [32]u8 { - var c: [4]u32 = undefined; - for (c) |_, i| { - c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + fn hashToBytes(out: *[64]u8, x: BlockVec) void { + for (x) |w, i| { + mem.writeIntLittle(u32, out[i * 4 ..][0..4], w); + } } - const ctx = initContext(keyToWords(key), c); - var x: BlockVec = undefined; - salsa20Core(x[0..], ctx, false); - var out: [32]u8 = undefined; - mem.writeIntLittle(u32, out[0..4], x[0]); - mem.writeIntLittle(u32, out[4..8], x[5]); - mem.writeIntLittle(u32, out[8..12], x[10]); - mem.writeIntLittle(u32, out[12..16], x[15]); - mem.writeIntLittle(u32, out[16..20], x[6]); - mem.writeIntLittle(u32, out[20..24], x[7]); - mem.writeIntLittle(u32, out[24..28], x[8]); - mem.writeIntLittle(u32, out[28..32], x[9]); - return out; - } -}; -const Salsa20Impl = if (builtin.cpu.arch == .x86_64) Salsa20VecImpl else Salsa20NonVecImpl; + fn salsaXor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void { + var ctx = initContext(key, d); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + salsaCore(x[0..], ctx, true); + hashToBytes(buf[0..], x); + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[9] += @boolToInt(@addWithOverflow(u32, ctx[8], 1, &ctx[8])); + } + if (i < in.len) { + salsaCore(x[0..], ctx, true); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } + } + } + + fn hsalsa(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + salsaCore(x[0..], ctx, false); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0]); + mem.writeIntLittle(u32, out[4..8], x[5]); + mem.writeIntLittle(u32, out[8..12], x[10]); + mem.writeIntLittle(u32, out[12..16], x[15]); + mem.writeIntLittle(u32, out[16..20], x[6]); + mem.writeIntLittle(u32, out[20..24], x[7]); + mem.writeIntLittle(u32, out[24..28], x[8]); + mem.writeIntLittle(u32, out[28..32], x[9]); + return out; + } + }; +} + +const SalsaImpl = if (builtin.cpu.arch == .x86_64) SalsaVecImpl else SalsaNonVecImpl; fn keyToWords(key: [32]u8) [8]u32 { var k: [8]u32 = undefined; @@ -315,52 +311,56 @@ fn keyToWords(key: [32]u8) [8]u32 { return k; } -fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [8]u8 } { +fn extend(comptime rounds: comptime_int, key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [8]u8 } { return .{ - .key = Salsa20Impl.hsalsa20(nonce[0..16].*, key), + .key = SalsaImpl(rounds).hsalsa(nonce[0..16].*, key), .nonce = nonce[16..24].*, }; } -/// The Salsa20 stream cipher. -pub const Salsa20 = struct { - /// Nonce length in bytes. - pub const nonce_length = 8; - /// Key length in bytes. - pub const key_length = 32; +/// The Salsa stream cipher. +pub fn Salsa(comptime rounds: comptime_int) type { + return struct { + /// Nonce length in bytes. + pub const nonce_length = 8; + /// Key length in bytes. + pub const key_length = 32; - /// Add the output of the Salsa20 stream cipher to `in` and stores the result into `out`. - /// WARNING: This function doesn't provide authenticated encryption. - /// Using the AEAD or one of the `box` versions is usually preferred. - pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void { - debug.assert(in.len == out.len); + /// Add the output of the Salsa stream cipher to `in` and stores the result into `out`. + /// WARNING: This function doesn't provide authenticated encryption. + /// Using the AEAD or one of the `box` versions is usually preferred. + pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void { + debug.assert(in.len == out.len); - var d: [4]u32 = undefined; - d[0] = mem.readIntLittle(u32, nonce[0..4]); - d[1] = mem.readIntLittle(u32, nonce[4..8]); - d[2] = @truncate(u32, counter); - d[3] = @truncate(u32, counter >> 32); - Salsa20Impl.salsa20Xor(out, in, keyToWords(key), d); - } -}; + var d: [4]u32 = undefined; + d[0] = mem.readIntLittle(u32, nonce[0..4]); + d[1] = mem.readIntLittle(u32, nonce[4..8]); + d[2] = @truncate(u32, counter); + d[3] = @truncate(u32, counter >> 32); + SalsaImpl(rounds).salsaXor(out, in, keyToWords(key), d); + } + }; +} -/// The XSalsa20 stream cipher. -pub const XSalsa20 = struct { - /// Nonce length in bytes. - pub const nonce_length = 24; - /// Key length in bytes. - pub const key_length = 32; +/// The XSalsa stream cipher. +pub fn XSalsa(comptime rounds: comptime_int) type { + return struct { + /// Nonce length in bytes. + pub const nonce_length = 24; + /// Key length in bytes. + pub const key_length = 32; - /// Add the output of the XSalsa20 stream cipher to `in` and stores the result into `out`. - /// WARNING: This function doesn't provide authenticated encryption. - /// Using the AEAD or one of the `box` versions is usually preferred. - pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void { - const extended = extend(key, nonce); - Salsa20.xor(out, in, counter, extended.key, extended.nonce); - } -}; + /// Add the output of the XSalsa stream cipher to `in` and stores the result into `out`. + /// WARNING: This function doesn't provide authenticated encryption. + /// Using the AEAD or one of the `box` versions is usually preferred. + pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void { + const extended = extend(rounds, key, nonce); + Salsa(rounds).xor(out, in, counter, extended.key, extended.nonce); + } + }; +} -/// The XSalsa20 stream cipher, combined with the Poly1305 MAC +/// The XSalsa stream cipher, combined with the Poly1305 MAC pub const XSalsa20Poly1305 = struct { /// Authentication tag length in bytes. pub const tag_length = Poly1305.mac_length; @@ -369,6 +369,8 @@ pub const XSalsa20Poly1305 = struct { /// Key length in bytes. pub const key_length = XSalsa20.key_length; + const rounds = 20; + /// c: ciphertext: output buffer should be of size m.len /// tag: authentication tag: output MAC /// m: message @@ -377,7 +379,7 @@ pub const XSalsa20Poly1305 = struct { /// k: private key pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void { debug.assert(c.len == m.len); - const extended = extend(k, npub); + const extended = extend(rounds, k, npub); var block0 = [_]u8{0} ** 64; const mlen0 = math.min(32, m.len); mem.copy(u8, block0[32..][0..mlen0], m[0..mlen0]); @@ -398,7 +400,7 @@ pub const XSalsa20Poly1305 = struct { /// k: private key pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) AuthenticationError!void { debug.assert(c.len == m.len); - const extended = extend(k, npub); + const extended = extend(rounds, k, npub); var block0 = [_]u8{0} ** 64; const mlen0 = math.min(32, c.len); mem.copy(u8, block0[32..][0..mlen0], c[0..mlen0]); @@ -482,7 +484,7 @@ pub const Box = struct { pub fn createSharedSecret(public_key: [public_length]u8, secret_key: [secret_length]u8) (IdentityElementError || WeakPublicKeyError)![shared_length]u8 { const p = try X25519.scalarmult(secret_key, public_key); const zero = [_]u8{0} ** 16; - return Salsa20Impl.hsalsa20(zero, p); + return SalsaImpl(20).hsalsa(zero, p); } /// Encrypt and authenticate a message using a recipient's public key `public_key` and a sender's `secret_key`.