crypto/25519: add scalar.random(), use CompressedScalar type

Add the ability to generate a random, canonical curve25519 scalar,
like we do for p256.

Also leverage the existing CompressedScalar type to represent these
scalars.
This commit is contained in:
Frank Denis 2022-05-26 01:38:32 +02:00 committed by Jakub Konka
parent 01607b54fc
commit b08d32ceb5

View File

@ -1,4 +1,5 @@
const std = @import("std");
const crypto = std.crypto;
const mem = std.mem;
const NonCanonicalError = std.crypto.errors.NonCanonicalError;
@ -15,7 +16,7 @@ pub const CompressedScalar = [32]u8;
pub const zero = [_]u8{0} ** 32;
/// Reject a scalar whose encoding is not canonical.
pub fn rejectNonCanonical(s: [32]u8) NonCanonicalError!void {
pub fn rejectNonCanonical(s: CompressedScalar) NonCanonicalError!void {
var c: u8 = 0;
var n: u8 = 1;
var i: usize = 31;
@ -32,34 +33,34 @@ pub fn rejectNonCanonical(s: [32]u8) NonCanonicalError!void {
}
/// Reduce a scalar to the field size.
pub fn reduce(s: [32]u8) [32]u8 {
pub fn reduce(s: CompressedScalar) CompressedScalar {
return Scalar.fromBytes(s).toBytes();
}
/// Reduce a 64-bytes scalar to the field size.
pub fn reduce64(s: [64]u8) [32]u8 {
pub fn reduce64(s: [64]u8) CompressedScalar {
return ScalarDouble.fromBytes64(s).toBytes();
}
/// Perform the X25519 "clamping" operation.
/// The scalar is then guaranteed to be a multiple of the cofactor.
pub inline fn clamp(s: *[32]u8) void {
pub inline fn clamp(s: *CompressedScalar) void {
s[0] &= 248;
s[31] = (s[31] & 127) | 64;
}
/// Return a*b (mod L)
pub fn mul(a: [32]u8, b: [32]u8) [32]u8 {
pub fn mul(a: CompressedScalar, b: CompressedScalar) CompressedScalar {
return Scalar.fromBytes(a).mul(Scalar.fromBytes(b)).toBytes();
}
/// Return a*b+c (mod L)
pub fn mulAdd(a: [32]u8, b: [32]u8, c: [32]u8) [32]u8 {
pub fn mulAdd(a: CompressedScalar, b: CompressedScalar, c: CompressedScalar) CompressedScalar {
return Scalar.fromBytes(a).mul(Scalar.fromBytes(b)).add(Scalar.fromBytes(c)).toBytes();
}
/// Return a*8 (mod L)
pub fn mul8(s: [32]u8) [32]u8 {
pub fn mul8(s: CompressedScalar) CompressedScalar {
var x = Scalar.fromBytes(s);
x = x.add(x);
x = x.add(x);
@ -68,12 +69,12 @@ pub fn mul8(s: [32]u8) [32]u8 {
}
/// Return a+b (mod L)
pub fn add(a: [32]u8, b: [32]u8) [32]u8 {
pub fn add(a: CompressedScalar, b: CompressedScalar) CompressedScalar {
return Scalar.fromBytes(a).add(Scalar.fromBytes(b)).toBytes();
}
/// Return -s (mod L)
pub fn neg(s: [32]u8) [32]u8 {
pub fn neg(s: CompressedScalar) CompressedScalar {
const fs: [64]u8 = field_size ++ [_]u8{0} ** 32;
var sx: [64]u8 = undefined;
mem.copy(u8, sx[0..32], s[0..]);
@ -89,23 +90,33 @@ pub fn neg(s: [32]u8) [32]u8 {
}
/// Return (a-b) (mod L)
pub fn sub(a: [32]u8, b: [32]u8) [32]u8 {
pub fn sub(a: CompressedScalar, b: CompressedScalar) CompressedScalar {
return add(a, neg(b));
}
/// Return a random scalar < L
pub fn random() CompressedScalar {
return Scalar.random().toBytes();
}
/// A scalar in unpacked representation
pub const Scalar = struct {
const Limbs = [5]u64;
limbs: Limbs = undefined,
/// Unpack a 32-byte representation of a scalar
pub fn fromBytes(bytes: [32]u8) Scalar {
pub fn fromBytes(bytes: CompressedScalar) Scalar {
return ScalarDouble.fromBytes32(bytes).reduce(5);
}
/// Unpack a 64-byte representation of a scalar
pub fn fromBytes64(bytes: [64]u8) Scalar {
return ScalarDouble.fromBytes64(bytes).reduce(5);
}
/// Pack a scalar into bytes
pub fn toBytes(expanded: *const Scalar) [32]u8 {
var bytes: [32]u8 = undefined;
pub fn toBytes(expanded: *const Scalar) CompressedScalar {
var bytes: CompressedScalar = undefined;
var i: usize = 0;
while (i < 4) : (i += 1) {
mem.writeIntLittle(u64, bytes[i * 7 ..][0..8], expanded.limbs[i]);
@ -114,7 +125,13 @@ pub const Scalar = struct {
return bytes;
}
/// Return x+y (mod l)
/// Return true if the scalar is zero
pub fn isZero(n: Scalar) bool {
const limbs = n.limbs;
return (limbs[0] | limbs[1] | limbs[2] | limbs[3] | limbs[4]) == 0;
}
/// Return x+y (mod L)
pub fn add(x: Scalar, y: Scalar) Scalar {
const carry0 = (x.limbs[0] + y.limbs[0]) >> 56;
const t0 = (x.limbs[0] + y.limbs[0]) & 0xffffffffffffff;
@ -171,7 +188,7 @@ pub const Scalar = struct {
return Scalar{ .limbs = .{ z00, z10, z20, z30, z40 } };
}
/// Return x*r (mod l)
/// Return x*r (mod L)
pub fn mul(x: Scalar, y: Scalar) Scalar {
const xy000 = @as(u128, x.limbs[0]) * @as(u128, y.limbs[0]);
const xy010 = @as(u128, x.limbs[0]) * @as(u128, y.limbs[1]);
@ -483,7 +500,7 @@ pub const Scalar = struct {
return Scalar{ .limbs = .{ z04, z14, z24, z34, z44 } };
}
/// Return x^2 (mod l)
/// Return x^2 (mod L)
pub fn sq(x: Scalar) Scalar {
return x.mul(x);
}
@ -503,7 +520,7 @@ pub const Scalar = struct {
return x.sqn(n).mul(y);
}
/// Return the inverse of a scalar (mod l), or 0 if x=0.
/// Return the inverse of a scalar (mod L), or 0 if x=0.
pub fn invert(x: Scalar) Scalar {
const _10 = x.sq();
const _11 = x.mul(_10);
@ -533,6 +550,18 @@ pub const Scalar = struct {
.sqn_mul(9, _1101011).sqn_mul(6, _1011).sqn_mul(14, _10010011).sqn_mul(10, _1100011)
.sqn_mul(9, _10010111).sqn_mul(10, _11110101).sqn_mul(8, _11010011).sqn_mul(8, _11101011);
}
/// Return a random scalar < L.
pub fn random() Scalar {
var s: [64]u8 = undefined;
while (true) {
crypto.random.bytes(&s);
const n = Scalar.fromBytes64(s);
if (!n.isZero()) {
return n;
}
}
}
};
const ScalarDouble = struct {
@ -549,7 +578,7 @@ const ScalarDouble = struct {
return ScalarDouble{ .limbs = limbs };
}
fn fromBytes32(bytes: [32]u8) ScalarDouble {
fn fromBytes32(bytes: CompressedScalar) ScalarDouble {
var limbs: Limbs = undefined;
var i: usize = 0;
while (i < 4) : (i += 1) {
@ -560,7 +589,7 @@ const ScalarDouble = struct {
return ScalarDouble{ .limbs = limbs };
}
fn toBytes(expanded_double: *ScalarDouble) [32]u8 {
fn toBytes(expanded_double: *ScalarDouble) CompressedScalar {
return expanded_double.reduce(10).toBytes();
}
@ -840,3 +869,15 @@ test "scalar field inversion" {
const recovered_x = inv.invert();
try std.testing.expectEqualSlices(u8, &bytes, &recovered_x.toBytes());
}
test "random scalar" {
const s1 = random();
const s2 = random();
try std.testing.expect(!mem.eql(u8, &s1, &s2));
}
test "64-bit reduction" {
const bytes = field_size ++ [_]u8{0} ** 32;
const x = Scalar.fromBytes64(bytes);
try std.testing.expect(x.isZero());
}