diff --git a/lib/std/crypto/25519/scalar.zig b/lib/std/crypto/25519/scalar.zig index ee3a59c244..c3170673d1 100644 --- a/lib/std/crypto/25519/scalar.zig +++ b/lib/std/crypto/25519/scalar.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const crypto = std.crypto; const mem = std.mem; const NonCanonicalError = std.crypto.errors.NonCanonicalError; @@ -15,7 +16,7 @@ pub const CompressedScalar = [32]u8; pub const zero = [_]u8{0} ** 32; /// Reject a scalar whose encoding is not canonical. -pub fn rejectNonCanonical(s: [32]u8) NonCanonicalError!void { +pub fn rejectNonCanonical(s: CompressedScalar) NonCanonicalError!void { var c: u8 = 0; var n: u8 = 1; var i: usize = 31; @@ -32,34 +33,34 @@ pub fn rejectNonCanonical(s: [32]u8) NonCanonicalError!void { } /// Reduce a scalar to the field size. -pub fn reduce(s: [32]u8) [32]u8 { +pub fn reduce(s: CompressedScalar) CompressedScalar { return Scalar.fromBytes(s).toBytes(); } /// Reduce a 64-bytes scalar to the field size. -pub fn reduce64(s: [64]u8) [32]u8 { +pub fn reduce64(s: [64]u8) CompressedScalar { return ScalarDouble.fromBytes64(s).toBytes(); } /// Perform the X25519 "clamping" operation. /// The scalar is then guaranteed to be a multiple of the cofactor. -pub inline fn clamp(s: *[32]u8) void { +pub inline fn clamp(s: *CompressedScalar) void { s[0] &= 248; s[31] = (s[31] & 127) | 64; } /// Return a*b (mod L) -pub fn mul(a: [32]u8, b: [32]u8) [32]u8 { +pub fn mul(a: CompressedScalar, b: CompressedScalar) CompressedScalar { return Scalar.fromBytes(a).mul(Scalar.fromBytes(b)).toBytes(); } /// Return a*b+c (mod L) -pub fn mulAdd(a: [32]u8, b: [32]u8, c: [32]u8) [32]u8 { +pub fn mulAdd(a: CompressedScalar, b: CompressedScalar, c: CompressedScalar) CompressedScalar { return Scalar.fromBytes(a).mul(Scalar.fromBytes(b)).add(Scalar.fromBytes(c)).toBytes(); } /// Return a*8 (mod L) -pub fn mul8(s: [32]u8) [32]u8 { +pub fn mul8(s: CompressedScalar) CompressedScalar { var x = Scalar.fromBytes(s); x = x.add(x); x = x.add(x); @@ -68,12 +69,12 @@ pub fn mul8(s: [32]u8) [32]u8 { } /// Return a+b (mod L) -pub fn add(a: [32]u8, b: [32]u8) [32]u8 { +pub fn add(a: CompressedScalar, b: CompressedScalar) CompressedScalar { return Scalar.fromBytes(a).add(Scalar.fromBytes(b)).toBytes(); } /// Return -s (mod L) -pub fn neg(s: [32]u8) [32]u8 { +pub fn neg(s: CompressedScalar) CompressedScalar { const fs: [64]u8 = field_size ++ [_]u8{0} ** 32; var sx: [64]u8 = undefined; mem.copy(u8, sx[0..32], s[0..]); @@ -89,23 +90,33 @@ pub fn neg(s: [32]u8) [32]u8 { } /// Return (a-b) (mod L) -pub fn sub(a: [32]u8, b: [32]u8) [32]u8 { +pub fn sub(a: CompressedScalar, b: CompressedScalar) CompressedScalar { return add(a, neg(b)); } +/// Return a random scalar < L +pub fn random() CompressedScalar { + return Scalar.random().toBytes(); +} + /// A scalar in unpacked representation pub const Scalar = struct { const Limbs = [5]u64; limbs: Limbs = undefined, /// Unpack a 32-byte representation of a scalar - pub fn fromBytes(bytes: [32]u8) Scalar { + pub fn fromBytes(bytes: CompressedScalar) Scalar { return ScalarDouble.fromBytes32(bytes).reduce(5); } + /// Unpack a 64-byte representation of a scalar + pub fn fromBytes64(bytes: [64]u8) Scalar { + return ScalarDouble.fromBytes64(bytes).reduce(5); + } + /// Pack a scalar into bytes - pub fn toBytes(expanded: *const Scalar) [32]u8 { - var bytes: [32]u8 = undefined; + pub fn toBytes(expanded: *const Scalar) CompressedScalar { + var bytes: CompressedScalar = undefined; var i: usize = 0; while (i < 4) : (i += 1) { mem.writeIntLittle(u64, bytes[i * 7 ..][0..8], expanded.limbs[i]); @@ -114,7 +125,13 @@ pub const Scalar = struct { return bytes; } - /// Return x+y (mod l) + /// Return true if the scalar is zero + pub fn isZero(n: Scalar) bool { + const limbs = n.limbs; + return (limbs[0] | limbs[1] | limbs[2] | limbs[3] | limbs[4]) == 0; + } + + /// Return x+y (mod L) pub fn add(x: Scalar, y: Scalar) Scalar { const carry0 = (x.limbs[0] + y.limbs[0]) >> 56; const t0 = (x.limbs[0] + y.limbs[0]) & 0xffffffffffffff; @@ -171,7 +188,7 @@ pub const Scalar = struct { return Scalar{ .limbs = .{ z00, z10, z20, z30, z40 } }; } - /// Return x*r (mod l) + /// Return x*r (mod L) pub fn mul(x: Scalar, y: Scalar) Scalar { const xy000 = @as(u128, x.limbs[0]) * @as(u128, y.limbs[0]); const xy010 = @as(u128, x.limbs[0]) * @as(u128, y.limbs[1]); @@ -483,7 +500,7 @@ pub const Scalar = struct { return Scalar{ .limbs = .{ z04, z14, z24, z34, z44 } }; } - /// Return x^2 (mod l) + /// Return x^2 (mod L) pub fn sq(x: Scalar) Scalar { return x.mul(x); } @@ -503,7 +520,7 @@ pub const Scalar = struct { return x.sqn(n).mul(y); } - /// Return the inverse of a scalar (mod l), or 0 if x=0. + /// Return the inverse of a scalar (mod L), or 0 if x=0. pub fn invert(x: Scalar) Scalar { const _10 = x.sq(); const _11 = x.mul(_10); @@ -533,6 +550,18 @@ pub const Scalar = struct { .sqn_mul(9, _1101011).sqn_mul(6, _1011).sqn_mul(14, _10010011).sqn_mul(10, _1100011) .sqn_mul(9, _10010111).sqn_mul(10, _11110101).sqn_mul(8, _11010011).sqn_mul(8, _11101011); } + + /// Return a random scalar < L. + pub fn random() Scalar { + var s: [64]u8 = undefined; + while (true) { + crypto.random.bytes(&s); + const n = Scalar.fromBytes64(s); + if (!n.isZero()) { + return n; + } + } + } }; const ScalarDouble = struct { @@ -549,7 +578,7 @@ const ScalarDouble = struct { return ScalarDouble{ .limbs = limbs }; } - fn fromBytes32(bytes: [32]u8) ScalarDouble { + fn fromBytes32(bytes: CompressedScalar) ScalarDouble { var limbs: Limbs = undefined; var i: usize = 0; while (i < 4) : (i += 1) { @@ -560,7 +589,7 @@ const ScalarDouble = struct { return ScalarDouble{ .limbs = limbs }; } - fn toBytes(expanded_double: *ScalarDouble) [32]u8 { + fn toBytes(expanded_double: *ScalarDouble) CompressedScalar { return expanded_double.reduce(10).toBytes(); } @@ -840,3 +869,15 @@ test "scalar field inversion" { const recovered_x = inv.invert(); try std.testing.expectEqualSlices(u8, &bytes, &recovered_x.toBytes()); } + +test "random scalar" { + const s1 = random(); + const s2 = random(); + try std.testing.expect(!mem.eql(u8, &s1, &s2)); +} + +test "64-bit reduction" { + const bytes = field_size ++ [_]u8{0} ** 32; + const x = Scalar.fromBytes64(bytes); + try std.testing.expect(x.isZero()); +}