diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 3807ec7d79..1da25abe17 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -197,6 +197,7 @@ pub const pwhash = struct { pub const sign = struct { pub const Ed25519 = @import("crypto/25519/ed25519.zig").Ed25519; pub const ecdsa = @import("crypto/ecdsa.zig"); + pub const mldsa = @import("crypto/ml_dsa.zig"); }; /// Stream ciphers. These do not provide any kind of authentication. diff --git a/lib/std/crypto/benchmark.zig b/lib/std/crypto/benchmark.zig index 85c7820617..f295276210 100644 --- a/lib/std/crypto/benchmark.zig +++ b/lib/std/crypto/benchmark.zig @@ -166,6 +166,9 @@ const signatures = [_]Crypto{ Crypto{ .ty = crypto.sign.ecdsa.EcdsaP256Sha256, .name = "ecdsa-p256" }, Crypto{ .ty = crypto.sign.ecdsa.EcdsaP384Sha384, .name = "ecdsa-p384" }, Crypto{ .ty = crypto.sign.ecdsa.EcdsaSecp256k1Sha256, .name = "ecdsa-secp256k1" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA44, .name = "ml-dsa-44" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA65, .name = "ml-dsa-65" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA87, .name = "ml-dsa-87" }, }; pub fn benchmarkSignature(comptime Signature: anytype, comptime signatures_count: comptime_int) !u64 { @@ -189,7 +192,12 @@ pub fn benchmarkSignature(comptime Signature: anytype, comptime signatures_count return throughput; } -const signature_verifications = [_]Crypto{Crypto{ .ty = crypto.sign.Ed25519, .name = "ed25519" }}; +const signature_verifications = [_]Crypto{ + Crypto{ .ty = crypto.sign.Ed25519, .name = "ed25519" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA44, .name = "ml-dsa-44" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA65, .name = "ml-dsa-65" }, + Crypto{ .ty = crypto.sign.mldsa.MLDSA87, .name = "ml-dsa-87" }, +}; pub fn benchmarkSignatureVerification(comptime Signature: anytype, comptime signatures_count: comptime_int) !u64 { const msg = [_]u8{0} ** 64; diff --git a/lib/std/crypto/errors.zig b/lib/std/crypto/errors.zig index 83375a6901..c38f134fad 100644 --- a/lib/std/crypto/errors.zig +++ b/lib/std/crypto/errors.zig @@ -34,5 +34,8 @@ pub const WeakPublicKeyError = error{WeakPublicKey}; /// Point is not in the prime order group pub const UnexpectedSubgroupError = error{UnexpectedSubgroup}; +/// Context string is too long +pub const ContextTooLongError = error{ContextTooLong}; + /// Any error related to cryptography operations -pub const Error = AuthenticationError || OutputTooLongError || IdentityElementError || EncodingError || SignatureVerificationError || KeyMismatchError || NonCanonicalError || NotSquareError || PasswordVerificationError || WeakParametersError || WeakPublicKeyError || UnexpectedSubgroupError; +pub const Error = AuthenticationError || OutputTooLongError || IdentityElementError || EncodingError || SignatureVerificationError || KeyMismatchError || NonCanonicalError || NotSquareError || PasswordVerificationError || WeakParametersError || WeakPublicKeyError || UnexpectedSubgroupError || ContextTooLongError; diff --git a/lib/std/crypto/ml_dsa.zig b/lib/std/crypto/ml_dsa.zig new file mode 100644 index 0000000000..9a699e5439 --- /dev/null +++ b/lib/std/crypto/ml_dsa.zig @@ -0,0 +1,3598 @@ +//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in NIST FIPS 204. +//! +//! ML-DSA is a post-quantum secure digital signature scheme based on the hardness +//! of the Module Learning With Errors (MLWE) and Module Short Integer Solution (MSIS) +//! problems over module lattices. +//! +//! We provide three parameter sets: +//! +//! 
- ML-DSA-44: NIST security category 2 (128-bit security)
+//! - ML-DSA-65: NIST security category 3 (192-bit security)
+//! - ML-DSA-87: NIST security category 5 (256-bit security)
+
+const std = @import("std");
+const builtin = @import("builtin");
+const testing = std.testing;
+const assert = std.debug.assert;
+const crypto = std.crypto;
+const errors = std.crypto.errors;
+const math = std.math;
+const mem = std.mem;
+const sha3 = crypto.hash.sha3;
+
+const ContextTooLongError = errors.ContextTooLongError;
+const EncodingError = errors.EncodingError;
+const SignatureVerificationError = errors.SignatureVerificationError;
+
+/// ML-DSA-44 (Module-Lattice-Based Digital Signature Algorithm, 44 parameter set)
+/// as specified in NIST FIPS 204.
+///
+/// This is a post-quantum signature scheme providing NIST security category 2,
+/// defined as being at least as hard to break as finding a collision in SHA-256.
+///
+/// Key sizes:
+///
+/// - Public key: 1312 bytes
+/// - Secret key: 2560 bytes
+/// - Signature: 2420 bytes
+///
+/// Example usage:
+///
+/// ```zig
+/// const kp = MLDSA44.KeyPair.generate();
+/// const msg = "Hello, post-quantum world!";
+/// const sig = try kp.sign(msg, null);
+/// try sig.verify(msg, kp.public_key);
+/// ```
+pub const MLDSA44 = MLDSAImpl(.{
+    .name = "ML-DSA-44",
+    .k = 4,
+    .l = 4,
+    .eta = 2,
+    .omega = 80,
+    .tau = 39,
+    .gamma1_bits = 17,
+    .gamma2 = 95232, // (Q-1)/88
+    .tr_size = 64,
+    .ctilde_size = 32,
+});
+
+/// ML-DSA-65 (Module-Lattice-Based Digital Signature Algorithm, 65 parameter set)
+/// as specified in NIST FIPS 204.
+///
+/// This is a post-quantum signature scheme providing NIST security category 3,
+/// defined as being at least as hard to break as an exhaustive key search on AES-192.
+///
+/// Key sizes:
+///
+/// - Public key: 1952 bytes
+/// - Secret key: 4032 bytes
+/// - Signature: 3309 bytes
+///
+/// This parameter set offers higher security than ML-DSA-44 at the cost of
+/// larger keys and signatures.
+pub const MLDSA65 = MLDSAImpl(.{
+    .name = "ML-DSA-65",
+    .k = 6,
+    .l = 5,
+    .eta = 4,
+    .omega = 55,
+    .tau = 49,
+    .gamma1_bits = 19,
+    .gamma2 = 261888, // (Q-1)/32
+    .tr_size = 64,
+    .ctilde_size = 48,
+});
+
+/// ML-DSA-87 (Module-Lattice-Based Digital Signature Algorithm, 87 parameter set)
+/// as specified in NIST FIPS 204.
+///
+/// This is a post-quantum signature scheme providing NIST security category 5,
+/// defined as being at least as hard to break as an exhaustive key search on AES-256.
+///
+/// Key sizes:
+///
+/// - Public key: 2592 bytes
+/// - Secret key: 4896 bytes
+/// - Signature: 4627 bytes
+///
+/// This parameter set offers the highest security level among the three ML-DSA
+/// variants, suitable for applications requiring maximum security assurance.
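+///
+/// Example usage with a context string (a sketch mirroring the ML-DSA-44
+/// example above; "example-app-v1" is an arbitrary application-chosen label,
+/// at most 255 bytes):
+///
+/// ```zig
+/// const kp = MLDSA87.KeyPair.generate();
+/// const msg = "Hello, post-quantum world!";
+/// const sig = try kp.signWithContext(msg, null, "example-app-v1");
+/// try sig.verifyWithContext(msg, kp.public_key, "example-app-v1");
+/// ```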
+pub const MLDSA87 = MLDSAImpl(.{
+    .name = "ML-DSA-87",
+    .k = 8,
+    .l = 7,
+    .eta = 2,
+    .omega = 75,
+    .tau = 60,
+    .gamma1_bits = 19,
+    .gamma2 = 261888, // (Q-1)/32
+    .tr_size = 64,
+    .ctilde_size = 64,
+});
+
+const N: usize = 256; // Degree of polynomials
+const Q: u32 = 8380417; // Modulus: 2^23 - 2^13 + 1
+const Q_BITS: u32 = 23;
+const D: u32 = 13; // Dropped bits in power2Round
+
+// Montgomery constant R = 2^32
+const R: u64 = 1 << 32;
+
+// Q_INV = -(Q^-1) mod 2^32, used in Montgomery reduction
+const Q_INV: u32 = 4236238847;
+
+// (256)^(-1) * R^2 mod q, used in inverse NTT
+const R_OVER_256: u32 = 41978;
+
+// Primitive 512th root of unity mod Q
+const ZETA: u32 = 1753;
+
+const Params = struct {
+    name: []const u8,
+
+    // Matrix dimensions
+    k: u8, // Height of matrix A
+    l: u8, // Width of matrix A
+
+    // Sampling parameter
+    eta: u8, // Bound for secret coefficients
+
+    // Hint parameters
+    omega: u16, // Maximum number of hint bits
+
+    // Challenge parameter
+    tau: u16, // Weight of challenge polynomial
+
+    // Rounding parameters
+    gamma1_bits: u8, // Bits for gamma1
+    gamma2: u32, // Parameter for decompose
+
+    // Sizes
+    tr_size: usize, // Size of tr hash
+    ctilde_size: usize, // Size of challenge hash
+};
+
+const Poly = struct {
+    cs: [N]u32,
+
+    const zero: Poly = .{ .cs = .{0} ** N };
+
+    // Add two polynomials (no normalization)
+    fn add(a: Poly, b: Poly) Poly {
+        var ret: Poly = undefined;
+        for (0..N) |i| {
+            ret.cs[i] = a.cs[i] + b.cs[i];
+        }
+        return ret;
+    }
+
+    // Subtract two polynomials (assumes b coefficients < 2q)
+    fn sub(a: Poly, b: Poly) Poly {
+        var ret: Poly = undefined;
+        for (0..N) |i| {
+            ret.cs[i] = a.cs[i] +% (@as(u32, 2 * Q) -% b.cs[i]);
+        }
+        return ret;
+    }
+
+    // Reduce each coefficient to < 2q
+    fn reduceLe2Q(p: Poly) Poly {
+        var ret = p;
+        for (0..N) |i| {
+            ret.cs[i] = le2Q(ret.cs[i]);
+        }
+        return ret;
+    }
+
+    // Normalize coefficients to [0, q)
+    fn normalize(p: Poly) Poly {
+        var ret = p;
+        for (0..N) |i| {
+            ret.cs[i] = modQ(ret.cs[i]);
+        }
+        return ret;
+    }
+
+    // Normalize assuming coefficients already < 2q
+    fn normalizeAssumingLe2Q(p: Poly) Poly {
+        var ret = p;
+        for (0..N) |i| {
+            ret.cs[i] = le2qModQ(ret.cs[i]);
+        }
+        return ret;
+    }
+
+    // Pointwise multiplication in NTT domain (Montgomery form)
+    fn mulHat(a: Poly, b: Poly) Poly {
+        var ret: Poly = undefined;
+        for (0..N) |i| {
+            ret.cs[i] = montReduceLe2Q(@as(u64, a.cs[i]) * @as(u64, b.cs[i]));
+        }
+        return ret;
+    }
+
+    // Forward NTT
+    fn ntt(p: Poly) Poly {
+        var ret = p;
+        ret.nttInPlace();
+        return ret;
+    }
+
+    // In-place forward NTT
+    fn nttInPlace(p: *Poly) void {
+        var k: usize = 0;
+        var l: usize = N / 2;
+
+        while (l > 0) : (l >>= 1) {
+            var offset: usize = 0;
+            while (offset < N - l) : (offset += 2 * l) {
+                k += 1;
+                const zeta: u64 = zetas[k];
+
+                for (offset..offset + l) |j| {
+                    const t = montReduceLe2Q(zeta * @as(u64, p.cs[j + l]));
+                    p.cs[j + l] = p.cs[j] +% (2 * Q -% t);
+                    p.cs[j] +%= t;
+                }
+            }
+        }
+    }
+
+    // Inverse NTT
+    fn invNTT(p: Poly) Poly {
+        var ret = p;
+        ret.invNTTInPlace();
+        return ret;
+    }
+
+    // In-place inverse NTT
+    fn invNTTInPlace(p: *Poly) void {
+        var k: usize = 0;
+        var l: usize = 1;
+
+        while (l < N) : (l <<= 1) {
+            var offset: usize = 0;
+            while (offset < N - l) : (offset += 2 * l) {
+                const zeta: u64 = inv_zetas[k];
+                k += 1;
+
+                for (offset..offset + l) |j| {
+                    const t = p.cs[j];
+                    p.cs[j] = t +% p.cs[j + l];
+                    p.cs[j + l] = montReduceLe2Q(zeta * @as(u64, t +% 256 * Q -% p.cs[j + l]));
+                }
+            }
+        }
+
+        for (0..N) |j| {
+            p.cs[j] =
montReduceLe2Q(@as(u64, R_OVER_256) * @as(u64, p.cs[j])); + } + } + + /// Apply Power2Round to all coefficients + /// Returns both t0 and t1 polynomials + fn power2RoundPoly(p: Poly) struct { t0: Poly, t1: Poly } { + var t0 = Poly.zero; + var t1 = Poly.zero; + for (0..N) |i| { + const result = power2Round(p.cs[i]); + t0.cs[i] = result.a0_plus_q; + t1.cs[i] = result.a1; + } + return .{ .t0 = t0, .t1 = t1 }; + } + + // Check if infinity norm exceeds bound + fn exceeds(p: Poly, bound: u32) bool { + var result: u32 = 0; + for (0..N) |i| { + const x = @as(i32, @intCast((Q - 1) / 2)) - @as(i32, @intCast(p.cs[i])); + const abs_x = x ^ (x >> 31); + const norm = @as(i32, @intCast((Q - 1) / 2)) - abs_x; + const exceeds_bit = @intFromBool(@as(u32, @intCast(norm)) >= bound); + result |= exceeds_bit; + } + return result != 0; + } +}; + +fn PolyVec(comptime len: u8) type { + return struct { + ps: [len]Poly, + + const Self = @This(); + const zero: Self = .{ .ps = .{Poly.zero} ** len }; + + /// Apply a unary operation to each polynomial in the vector + fn map(v: Self, comptime op: fn (Poly) Poly) Self { + var ret: Self = undefined; + inline for (0..len) |i| { + ret.ps[i] = op(v.ps[i]); + } + return ret; + } + + /// Apply a binary operation pairwise to two vectors + fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self { + var ret: Self = undefined; + inline for (0..len) |i| { + ret.ps[i] = op(a.ps[i], b.ps[i]); + } + return ret; + } + + /// Apply a binary operation between a vector and a scalar polynomial + fn mapBinaryPoly(v: Self, scalar: Poly, comptime op: fn (Poly, Poly) Poly) Self { + var ret: Self = undefined; + inline for (0..len) |i| { + ret.ps[i] = op(v.ps[i], scalar); + } + return ret; + } + + fn add(a: Self, b: Self) Self { + return mapBinary(a, b, Poly.add); + } + + fn sub(a: Self, b: Self) Self { + return mapBinary(a, b, Poly.sub); + } + + fn ntt(v: Self) Self { + return map(v, Poly.ntt); + } + + fn invNTT(v: Self) Self { + return map(v, Poly.invNTT); + } + + fn normalize(v: Self) Self { + return map(v, Poly.normalize); + } + + fn reduceLe2Q(v: Self) Self { + return map(v, Poly.reduceLe2Q); + } + + fn normalizeAssumingLe2Q(v: Self) Self { + return map(v, Poly.normalizeAssumingLe2Q); + } + + // Check if any polynomial in the vector exceeds the bound + fn exceeds(v: Self, bound: u32) bool { + var result = false; + for (0..len) |i| { + result = result or v.ps[i].exceeds(bound); + } + return result; + } + + /// Apply Power2Round to each polynomial in the vector + /// Returns both t0 and t1 vectors + fn power2Round(v: Self, t0_out: *Self) Self { + var t1: Self = undefined; + for (0..len) |i| { + const result = v.ps[i].power2RoundPoly(); + t0_out.ps[i] = result.t0; + t1.ps[i] = result.t1; + } + return t1; + } + + /// Generic packing function for vectors + fn packWith( + v: Self, + buf: []u8, + comptime poly_size: usize, + comptime pack_fn: fn (Poly, []u8) void, + ) void { + inline for (0..len) |i| { + const offset = i * poly_size; + pack_fn(v.ps[i], buf[offset..][0..poly_size]); + } + } + + /// Generic unpacking function for vectors + fn unpackWith( + comptime poly_size: usize, + comptime unpack_fn: fn ([]const u8) Poly, + buf: []const u8, + ) Self { + var result: Self = undefined; + inline for (0..len) |i| { + const offset = i * poly_size; + result.ps[i] = unpack_fn(buf[offset..][0..poly_size]); + } + return result; + } + + /// Pack T1 vector to bytes + fn packT1(v: Self, buf: []u8) void { + const poly_size = (N * (Q_BITS - D)) / 8; + packWith(v, buf, poly_size, polyPackT1); + } + 
+ /// Unpack T1 vector from bytes + fn unpackT1(bytes: []const u8) Self { + const poly_size = (N * (Q_BITS - D)) / 8; + return unpackWith(poly_size, polyUnpackT1, bytes); + } + + /// Pack T0 vector to bytes + fn packT0(v: Self, buf: []u8) void { + const poly_size = (N * D) / 8; + packWith(v, buf, poly_size, polyPackT0); + } + + /// Unpack T0 vector from bytes + fn unpackT0(buf: []const u8) Self { + const poly_size = (N * D) / 8; + return unpackWith(poly_size, polyUnpackT0, buf); + } + + /// Pack vector with coefficients in [-eta, eta] + fn packLeqEta(v: Self, comptime eta: u8, buf: []u8) void { + const poly_size = if (eta == 2) 96 else 128; + const pack_fn = struct { + fn pack(p: Poly, b: []u8) void { + polyPackLeqEta(p, eta, b); + } + }.pack; + packWith(v, buf, poly_size, pack_fn); + } + + /// Unpack vector with coefficients in [-eta, eta] + fn unpackLeqEta(comptime eta: u8, buf: []const u8) Self { + const poly_size = if (eta == 2) 96 else 128; + const unpack_fn = struct { + fn unpack(b: []const u8) Poly { + return polyUnpackLeqEta(eta, b); + } + }.unpack; + return unpackWith(poly_size, unpack_fn, buf); + } + + /// Pack vector of polynomials with coefficients < gamma1 + fn packLeGamma1(v: Self, comptime gamma1_bits: u8, buf: []u8) void { + const poly_size = ((gamma1_bits + 1) * N) / 8; + const pack_fn = struct { + fn pack(p: Poly, b: []u8) void { + polyPackLeGamma1(p, gamma1_bits, b); + } + }.pack; + packWith(v, buf, poly_size, pack_fn); + } + + /// Unpack vector of polynomials with coefficients < gamma1 + fn unpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Self { + const poly_size = ((gamma1_bits + 1) * N) / 8; + const unpack_fn = struct { + fn unpack(b: []const u8) Poly { + return polyUnpackLeGamma1(gamma1_bits, b); + } + }.unpack; + return unpackWith(poly_size, unpack_fn, buf); + } + + /// Pack high bits w1 for signature verification + fn packW1(v: Self, comptime gamma1_bits: u8, buf: []u8) void { + const poly_size = (N * (Q_BITS - gamma1_bits)) / 8; + const pack_fn = struct { + fn pack(p: Poly, b: []u8) void { + polyPackW1(p, gamma1_bits, b); + } + }.pack; + packWith(v, buf, poly_size, pack_fn); + } + + /// Decompose each polynomial in the vector into high and low bits + fn decomposeVec(v: Self, comptime gamma2: u32, w0_out: *Self) Self { + var w1: Self = undefined; + for (0..len) |i| { + for (0..N) |j| { + const r = decompose(v.ps[i].cs[j], gamma2); + w0_out.ps[i].cs[j] = r.a0_plus_q; + w1.ps[i].cs[j] = r.a1; + } + } + return w1; + } + + /// Create hints for vector, returns hint population count + fn makeHintVec(w0mcs2pct0: Self, w1: Self, comptime gamma2: u32) struct { hint: Self, pop: u32 } { + var hint: Self = undefined; + var pop: u32 = 0; + for (0..len) |i| { + const result = polyMakeHint(w0mcs2pct0.ps[i], w1.ps[i], gamma2); + hint.ps[i] = result.hint; + pop += result.count; + } + return .{ .hint = hint, .pop = pop }; + } + + /// Apply hints to recover high bits + fn useHint(v: Self, hint: Self, comptime gamma2: u32) Self { + var result: Self = undefined; + for (0..len) |i| { + result.ps[i] = polyUseHint(v.ps[i], hint.ps[i], gamma2); + } + return result; + } + + /// Multiply vector by 2^D (left shift) + fn mulBy2toD(v: Self) Self { + var result: Self = undefined; + for (0..len) |i| { + for (0..N) |j| { + result.ps[i].cs[j] = v.ps[i].cs[j] << D; + } + } + return result; + } + + /// Sample vector with coefficients uniformly in (-gamma1, gamma1] + /// Wraps expandMask (FIPS 204: ExpandMask) + fn deriveUniformLeGamma1(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: 
u16) Self { + var result: Self = undefined; + for (0..len) |i| { + result.ps[i] = expandMask(gamma1_bits, seed, nonce + @as(u16, @intCast(i))); + } + return result; + } + + /// Pack hints into bytes + /// Format: for each polynomial, find positions where hint[i]=1, encode those positions + fn packHint(v: Self, comptime omega: u16, buf: []u8) bool { + var idx: usize = 0; + var count: u32 = 0; + + for (0..len) |i| { + for (0..N) |j| { + if (v.ps[i].cs[j] != 0) { + count += 1; + } + } + } + + if (count > omega) { + return false; + } + + // Hint encoding format per FIPS 204: + // First omega bytes: positions of set bits across all polynomials + // Last len bytes: boundary indices showing where each polynomial's hints end + for (0..len) |i| { + for (0..N) |j| { + if (v.ps[i].cs[j] != 0) { + buf[idx] = @intCast(j); + idx += 1; + } + } + buf[omega + i] = @intCast(idx); + } + + while (idx < omega) : (idx += 1) { + buf[idx] = 0; + } + + return true; + } + + /// Unpack hints from bytes + fn unpackHint(comptime omega: u16, buf: []const u8) ?Self { + var result: Self = .{ .ps = .{Poly.zero} ** len }; + var prev_sop: u8 = 0; // previous switch-over-point + + for (0..len) |i| { + const sop = buf[omega + i]; // switch-over-point + if (sop < prev_sop or sop > omega) { + return null; // ensures switch-over-points are increasing + } + + var j = prev_sop; + while (j < sop) : (j += 1) { + // Validation: indices must be strictly increasing within each polynomial + if (j > prev_sop and buf[j] <= buf[j - 1]) { + return null; + } + const pos = buf[j]; + if (pos >= N) { + return null; + } + result.ps[i].cs[pos] = 1; + } + prev_sop = sop; + } + + var j = prev_sop; + while (j < omega) : (j += 1) { + if (buf[j] != 0) { + return null; + } + } + + return result; + } + }; +} + +// Matrix of k x l polynomials + +fn Mat(comptime k: u8, comptime l: u8) type { + return struct { + rows: [k]PolyVec(l), + + const Self = @This(); + const VecL = PolyVec(l); + const VecK = PolyVec(k); + + /// Expand matrix A from seed rho using SHAKE-128 + /// This is the ExpandA function from FIPS 204 + fn derive(rho: *const [32]u8) Self { + var m: Self = undefined; + for (0..k) |i| { + if (i + 1 < k) { + @prefetch(&m.rows[i + 1], .{ .rw = .write, .locality = 2 }); + } + for (0..l) |j| { + // Nonce is i*256 + j + const nonce: u16 = (@as(u16, @intCast(i)) << 8) | @as(u16, @intCast(j)); + m.rows[i].ps[j] = polyDeriveUniform(rho, nonce); + } + } + return m; + } + + /// Multiply matrix by vector in NTT domain and return result in regular domain. + /// Takes a vector in NTT form and returns the product in regular form. + fn mulVec(self: Self, v_hat: VecL) VecK { + var result = VecK.zero; + for (0..k) |i| { + result.ps[i] = dotHat(l, self.rows[i], v_hat); + result.ps[i] = result.ps[i].reduceLe2Q(); + result.ps[i] = result.ps[i].invNTT(); + } + return result; + } + + /// Multiply matrix by vector in NTT domain and return result in NTT domain. + /// Takes a vector in NTT form and returns the product in NTT form. 
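+        /// Unlike mulVec, no reduction or inverse NTT is applied here;
+        /// callers are expected to reduce the result themselves.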
+        fn mulVecHat(self: Self, v_hat: VecL) VecK {
+            var result: VecK = undefined;
+            for (0..k) |i| {
+                result.ps[i] = dotHat(l, self.rows[i], v_hat);
+            }
+            return result;
+        }
+    };
+}
+
+// Dot product in NTT domain
+fn dotHat(comptime len: u8, a: PolyVec(len), b: PolyVec(len)) Poly {
+    var ret = Poly.zero;
+    for (0..len) |i| {
+        const prod = a.ps[i].mulHat(b.ps[i]);
+        ret = ret.add(prod);
+    }
+    return ret;
+}
+
+// Modular arithmetic operations
+
+// Reduce x to [0, 2q) using the fact that 2^23 = 2^13 - 1 (mod q)
+fn le2Q(x: u32) u32 {
+    // Write x = x1 * 2^23 + x2 with x2 < 2^23 and x1 < 2^9
+    // Then x = x2 + x1 * 2^13 - x1 (mod q)
+    // and x2 + x1 * 2^13 - x1 < 2^23 + 2^22 < 2q
+    const x1 = x >> 23;
+    const x2 = x & 0x7FFFFF; // 2^23 - 1
+    return x2 +% (x1 << 13) -% x1;
+}
+
+// Reduce x to [0, q)
+fn modQ(x: u32) u32 {
+    return le2qModQ(le2Q(x));
+}
+
+// Given x < 2q, reduce to [0, q)
+fn le2qModQ(x: u32) u32 {
+    const r = x -% Q;
+    const mask = signMask(u32, r);
+    return r +% (mask & Q);
+}
+
+// Montgomery reduction: for x < q*2^32, return y < 2q where y ≡ x*R^(-1) (mod q)
+// where R = 2^32. This is used for efficient modular multiplication in NTT operations.
+fn montReduceLe2Q(x: u64) u32 {
+    const m = (x *% Q_INV) & 0xffffffff;
+    return @truncate((x +% m * @as(u64, Q)) >> 32);
+}
+
+// Precomputed zetas for NTT (Montgomery form)
+// zetas[i] = zeta^brv(i) * R mod q
+const zetas = computeZetas();
+
+fn computeZetas() [N]u32 {
+    @setEvalBranchQuota(100000);
+    var ret: [N]u32 = undefined;
+
+    for (0..N) |i| {
+        const brv_i = @bitReverse(@as(u8, @intCast(i)));
+        const power = modularPow(u32, ZETA, brv_i, Q);
+        ret[i] = toMont(power);
+    }
+
+    return ret;
+}
+
+// Precomputed inverse zetas for inverse NTT
+const inv_zetas = computeInvZetas();
+
+fn computeInvZetas() [N]u32 {
+    @setEvalBranchQuota(100000);
+    var ret: [N]u32 = undefined;
+
+    const inv_zeta = modularInverse(u32, ZETA, Q);
+
+    for (0..N) |i| {
+        const idx = 255 - i;
+        const brv_idx = @bitReverse(@as(u8, @intCast(idx)));
+
+        // Exponent is -(brv_idx - 256) = 256 - brv_idx
+        const exp: u32 = @as(u32, 256) - brv_idx;
+
+        // Compute inv_zeta^exp
+        const power = modularPow(u32, inv_zeta, exp, Q);
+
+        // Convert to Montgomery form
+        ret[i] = toMont(power);
+    }
+
+    return ret;
+}
+
+// Convert to Montgomery form: x -> x * R mod q
+fn toMont(x: u32) u32 {
+    // x * R mod q = montReduceLe2Q(x * R^2 mod q), so a single Montgomery
+    // reduction against a comptime-precomputed R^2 mod q suffices.
+
+    // R mod q = 2^32 mod q, computed by repeated doubling
+    const r_mod_q = comptime blk: {
+        var r: u64 = 1;
+        for (0..32) |_| {
+            r = (r * 2) % Q;
+        }
+        break :blk @as(u32, @intCast(r));
+    };
+
+    const r2_mod_q = comptime blk: {
+        const r = @as(u64, r_mod_q);
+        break :blk @as(u32, @intCast((r * r) % Q));
+    };
+
+    return montReduceLe2Q(@as(u64, x) * @as(u64, r2_mod_q));
+}
+
+/// Splits 0 ≤ a < Q into a0 and a1 with a = a1*2^D + a0
+/// and -2^(D-1) < a0 ≤ 2^(D-1). Returns a0 + Q and a1.
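+/// For example, with D = 13: a = 12345 = 2 * 2^13 - 4039, so a1 = 2 and
+/// a0 = -4039, returned as a0_plus_q = Q - 4039.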
+/// FIPS 204: Power2Round (Algorithm 19)
+fn power2Round(a: u32) struct { a0_plus_q: u32, a1: u32 } {
+    // We effectively compute a0 = a mod± 2^D
+    // and a1 = (a - a0) / 2^D
+    var a0 = a & ((1 << D) - 1); // a mod 2^D
+
+    // a0 is one of 0, 1, ..., 2^(D-1)-1, 2^(D-1), 2^(D-1)+1, ..., 2^D-1
+    a0 -%= (1 << (D - 1)) + 1;
+    // now a0 is -2^(D-1)-1, -2^(D-1), ..., -2, -1, 0, ..., 2^(D-1)-2
+
+    // Next, add 2^D to those a0 that are negative (seen as i32)
+    a0 +%= @as(u32, @bitCast(@as(i32, @bitCast(a0)) >> 31)) & (1 << D);
+    // now a0 is 2^(D-1)-1, 2^(D-1), ..., 2^D-2, 2^D-1, 0, ..., 2^(D-1)-2
+
+    a0 -%= (1 << (D - 1)) - 1;
+    // now a0 is 0, 1, 2, ..., 2^(D-1)-1, 2^(D-1), -2^(D-1)+1, ..., -1
+
+    const a0_plus_q = Q +% a0;
+    const a1 = (a -% a0) >> D;
+
+    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
+}
+
+/// Splits 0 ≤ a < q into a0 and a1 with a = a1*alpha + a0 with -alpha/2 < a0 ≤ alpha/2,
+/// except when we would have a1 = (q-1)/alpha in which case a1=0 is taken
+/// and -alpha/2 ≤ a0 < 0. Returns a0 + q. Note 0 ≤ a1 < (q-1)/alpha.
+/// Recall alpha = 2*gamma2.
+fn decompose(a: u32, comptime gamma2: u32) struct { a0_plus_q: u32, a1: u32 } {
+    const alpha = 2 * gamma2;
+
+    // a1 = ⌈a / 128⌉
+    var a1 = (a + 127) >> 7;
+
+    if (alpha == 523776) {
+        // For ML-DSA-65 and ML-DSA-87: gamma2 = 261888, alpha = 523776
+        // 1025/2^22 is close enough to 1/4092 so that a1 becomes a/alpha rounded down
+        a1 = ((a1 * 1025 + (1 << 21)) >> 22);
+
+        // For the corner-case a1 = (q-1)/alpha = 16, we have to set a1=0
+        a1 &= 15;
+    } else if (alpha == 190464) {
+        // For ML-DSA-44: gamma2 = 95232, alpha = 190464
+        // 11275/2^24 is close enough to 1/1488 so that a1 becomes a/alpha rounded down
+        a1 = ((a1 * 11275) + (1 << 23)) >> 24;
+
+        // For the corner-case a1 = (q-1)/alpha = 44, we have to set a1=0
+        a1 ^= @as(u32, @bitCast(@as(i32, @bitCast(43 -% a1)) >> 31)) & a1;
+    } else {
+        @compileError("unsupported gamma2/alpha value");
+    }
+
+    var a0_plus_q = a -% a1 * alpha;
+
+    // In the corner-case, when we set a1=0, we will incorrectly
+    // have a0 > (q-1)/2 and we'll need to subtract q. As we
+    // return a0 + q, that comes down to adding q if a0 < (q-1)/2.
+    a0_plus_q +%= @as(u32, @bitCast(@as(i32, @bitCast(a0_plus_q -% (Q - 1) / 2)) >> 31)) & Q;
+
+    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
+}
+
+/// Creates a hint bit to help recover high bits after a small perturbation.
+/// Given:
+/// - z0: the modified low bits (r0 - f mod Q) where f is small
+/// - r1: the original high bits
+/// Returns 1 if a hint is needed, 0 otherwise.
+///
+/// This implements makeHint from FIPS 204. The hint helps recover r1 from
+/// r' = r - f without knowing f explicitly.
+fn makeHint(z0: u32, r1: u32, comptime gamma2: u32) u32 {
+    // If -alpha/2 < r0 - f <= alpha/2, then r1*alpha + r0 - f is a valid
+    // decomposition of r' with the restrictions of decompose() and so r'1 = r1.
+    // So the hint should be 0. This is covered by the first two inequalities.
+    // There is one other case: if r0 - f = -alpha/2, then r1*alpha + r0 - f is
+    // also a valid decomposition if r1 = 0. In the other cases a one is carried
+    // and the hint should be 1.
+
+    const cond1 = @intFromBool(z0 <= gamma2);
+    const cond2 = @intFromBool(z0 > Q - gamma2);
+    const eq_gamma2 = @intFromBool(z0 == Q - gamma2);
+    const r1_is_zero = @intFromBool(r1 == 0);
+    const cond3 = eq_gamma2 & r1_is_zero;
+
+    return 1 - (cond1 | cond2 | cond3);
+}
+
+/// Uses a hint to reconstruct high bits from a perturbed value.
+/// Given: +/// - rp: the perturbed value (r' = r - f) +/// - hint: the hint bit from makeHint +/// Returns the reconstructed high bits r1. +/// +/// This implements useHint from FIPS 204. +fn useHint(rp: u32, hint: u32, comptime gamma2: u32) u32 { + const decomp = decompose(rp, gamma2); + const rp0_plus_q = decomp.a0_plus_q; + var rp1 = decomp.a1; + + if (hint == 0) { + return rp1; + } + + // Depending on gamma2, handle the adjustment differently + if (gamma2 == 261888) { + // ML-DSA-65 and ML-DSA-87: max r1 is 15 + if (rp0_plus_q > Q) { + rp1 = (rp1 + 1) & 15; + } else { + rp1 = (rp1 -% 1) & 15; + } + } else if (gamma2 == 95232) { + // ML-DSA-44: max r1 is 43 + if (rp0_plus_q > Q) { + if (rp1 == 43) { + rp1 = 0; + } else { + rp1 += 1; + } + } else { + if (rp1 == 0) { + rp1 = 43; + } else { + rp1 -= 1; + } + } + } else { + @compileError("unsupported gamma2 value"); + } + + return rp1; +} + +/// Creates a hint polynomial for the difference between perturbed and original high bits. +/// Returns the number of hint bits set to 1 (the population count). +/// +/// This is used during signature generation to create hints that help verification +/// recover the high bits without access to the secret. +fn polyMakeHint(p0: Poly, p1: Poly, comptime gamma2: u32) struct { hint: Poly, count: u32 } { + var hint = Poly.zero; + var count: u32 = 0; + + for (0..N) |i| { + const h = makeHint(p0.cs[i], p1.cs[i], gamma2); + hint.cs[i] = h; + count += h; + } + + return .{ .hint = hint, .count = count }; +} + +/// Applies hints to reconstruct high bits from a perturbed polynomial. +/// +/// This is used during signature verification to recover the high bits +/// using the hints provided in the signature. +fn polyUseHint(q: Poly, hint: Poly, comptime gamma2: u32) Poly { + var result = Poly.zero; + + for (0..N) |i| { + result.cs[i] = useHint(q.cs[i], hint.cs[i], gamma2); + } + + return result; +} + +/// Pack polynomial with coefficients in [Q-eta, Q+eta] into bytes. +/// For eta=2: packs coefficients into 3 bits each (96 bytes total) +/// For eta=4: packs coefficients into 4 bits each (128 bytes total) +/// Assumes coefficients are not normalized, but in [q-η, q+η]. +fn polyPackLeqEta(p: Poly, comptime eta: u8, buf: []u8) void { + comptime { + if (eta != 2 and eta != 4) { + @compileError("eta must be 2 or 4"); + } + } + + if (eta == 2) { + // 3 bits per coefficient: pack 8 coefficients into 3 bytes + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 3) { + const c0 = Q + eta - p.cs[j]; + const c1 = Q + eta - p.cs[j + 1]; + const c2 = Q + eta - p.cs[j + 2]; + const c3 = Q + eta - p.cs[j + 3]; + const c4 = Q + eta - p.cs[j + 4]; + const c5 = Q + eta - p.cs[j + 5]; + const c6 = Q + eta - p.cs[j + 6]; + const c7 = Q + eta - p.cs[j + 7]; + + buf[i] = @truncate(c0 | (c1 << 3) | (c2 << 6)); + buf[i + 1] = @truncate((c2 >> 2) | (c3 << 1) | (c4 << 4) | (c5 << 7)); + buf[i + 2] = @truncate((c5 >> 1) | (c6 << 2) | (c7 << 5)); + + j += 8; + } + } else { // eta == 4 + // 4 bits per coefficient: pack 2 coefficients into 1 byte + var j: usize = 0; + for (0..buf.len) |i| { + const c0 = Q + eta - p.cs[j]; + const c1 = Q + eta - p.cs[j + 1]; + buf[i] = @truncate(c0 | (c1 << 4)); + j += 2; + } + } +} + +/// Unpack polynomial with coefficients in [Q-eta, Q+eta] from bytes. +/// Output coefficients will not be normalized, but in [q-η, q+η]. 
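+/// For eta = 2, each group of 3 bytes holds 8 coefficients of 3 bits each;
+/// byte 0 carries c0 (bits 0-2), c1 (bits 3-5) and the low 2 bits of c2.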
+fn polyUnpackLeqEta(comptime eta: u8, buf: []const u8) Poly { + comptime { + if (eta != 2 and eta != 4) { + @compileError("eta must be 2 or 4"); + } + } + + var p = Poly.zero; + + if (eta == 2) { + // 3 bits per coefficient: unpack 8 coefficients from 3 bytes + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 3) { + p.cs[j] = Q + eta - (buf[i] & 7); + p.cs[j + 1] = Q + eta - ((buf[i] >> 3) & 7); + p.cs[j + 2] = Q + eta - ((buf[i] >> 6) | ((buf[i + 1] << 2) & 7)); + p.cs[j + 3] = Q + eta - ((buf[i + 1] >> 1) & 7); + p.cs[j + 4] = Q + eta - ((buf[i + 1] >> 4) & 7); + p.cs[j + 5] = Q + eta - ((buf[i + 1] >> 7) | ((buf[i + 2] << 1) & 7)); + p.cs[j + 6] = Q + eta - ((buf[i + 2] >> 2) & 7); + p.cs[j + 7] = Q + eta - ((buf[i + 2] >> 5) & 7); + j += 8; + } + } else { // eta == 4 + // 4 bits per coefficient: unpack 2 coefficients from 1 byte + var j: usize = 0; + for (0..buf.len) |i| { + p.cs[j] = Q + eta - (buf[i] & 15); + p.cs[j + 1] = Q + eta - (buf[i] >> 4); + j += 2; + } + } + + return p; +} + +/// Pack polynomial with coefficients < 1024 (T1) into bytes. +/// Packs 10 bits per coefficient: 4 coefficients into 5 bytes. +/// Assumes coefficients are normalized. +fn polyPackT1(p: Poly, buf: []u8) void { + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 5) { + buf[i] = @truncate(p.cs[j]); + buf[i + 1] = @truncate((p.cs[j] >> 8) | (p.cs[j + 1] << 2)); + buf[i + 2] = @truncate((p.cs[j + 1] >> 6) | (p.cs[j + 2] << 4)); + buf[i + 3] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 6)); + buf[i + 4] = @truncate(p.cs[j + 3] >> 2); + j += 4; + } +} + +/// Unpack polynomial with coefficients < 1024 (T1) from bytes. +/// Output coefficients will be normalized. +fn polyUnpackT1(buf: []const u8) Poly { + var p = Poly.zero; + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 5) { + p.cs[j] = (@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x3ff; + p.cs[j + 1] = ((@as(u32, buf[i + 1]) >> 2) | (@as(u32, buf[i + 2]) << 6)) & 0x3ff; + p.cs[j + 2] = ((@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4)) & 0x3ff; + p.cs[j + 3] = ((@as(u32, buf[i + 3]) >> 6) | (@as(u32, buf[i + 4]) << 2)) & 0x3ff; + j += 4; + } + return p; +} + +/// Pack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) into bytes. +/// Packs 13 bits per coefficient: 8 coefficients into 13 bytes. +/// Assumes coefficients are not normalized, but in (q-2^(D-1), q+2^(D-1)]. +fn polyPackT0(p: Poly, buf: []u8) void { + const bound = 1 << (D - 1); + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 13) { + const p0 = Q + bound - p.cs[j]; + const p1 = Q + bound - p.cs[j + 1]; + const p2 = Q + bound - p.cs[j + 2]; + const p3 = Q + bound - p.cs[j + 3]; + const p4 = Q + bound - p.cs[j + 4]; + const p5 = Q + bound - p.cs[j + 5]; + const p6 = Q + bound - p.cs[j + 6]; + const p7 = Q + bound - p.cs[j + 7]; + + buf[i] = @truncate(p0 >> 0); + buf[i + 1] = @truncate((p0 >> 8) | (p1 << 5)); + buf[i + 2] = @truncate(p1 >> 3); + buf[i + 3] = @truncate((p1 >> 11) | (p2 << 2)); + buf[i + 4] = @truncate((p2 >> 6) | (p3 << 7)); + buf[i + 5] = @truncate(p3 >> 1); + buf[i + 6] = @truncate((p3 >> 9) | (p4 << 4)); + buf[i + 7] = @truncate(p4 >> 4); + buf[i + 8] = @truncate((p4 >> 12) | (p5 << 1)); + buf[i + 9] = @truncate((p5 >> 7) | (p6 << 6)); + buf[i + 10] = @truncate(p6 >> 2); + buf[i + 11] = @truncate((p6 >> 10) | (p7 << 3)); + buf[i + 12] = @truncate(p7 >> 5); + + j += 8; + } +} + +/// Unpack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) from bytes. 
+/// Output coefficients will not be normalized, but in (-2^(D-1), 2^(D-1)]. +fn polyUnpackT0(buf: []const u8) Poly { + const bound = 1 << (D - 1); + var p = Poly.zero; + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 13) { + p.cs[j] = Q + bound - ((@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x1fff); + p.cs[j + 1] = Q + bound - (((@as(u32, buf[i + 1]) >> 5) | (@as(u32, buf[i + 2]) << 3) | (@as(u32, buf[i + 3]) << 11)) & 0x1fff); + p.cs[j + 2] = Q + bound - (((@as(u32, buf[i + 3]) >> 2) | (@as(u32, buf[i + 4]) << 6)) & 0x1fff); + p.cs[j + 3] = Q + bound - (((@as(u32, buf[i + 4]) >> 7) | (@as(u32, buf[i + 5]) << 1) | (@as(u32, buf[i + 6]) << 9)) & 0x1fff); + p.cs[j + 4] = Q + bound - (((@as(u32, buf[i + 6]) >> 4) | (@as(u32, buf[i + 7]) << 4) | (@as(u32, buf[i + 8]) << 12)) & 0x1fff); + p.cs[j + 5] = Q + bound - (((@as(u32, buf[i + 8]) >> 1) | (@as(u32, buf[i + 9]) << 7)) & 0x1fff); + p.cs[j + 6] = Q + bound - (((@as(u32, buf[i + 9]) >> 6) | (@as(u32, buf[i + 10]) << 2) | (@as(u32, buf[i + 11]) << 10)) & 0x1fff); + p.cs[j + 7] = Q + bound - ((@as(u32, buf[i + 11]) >> 3) | (@as(u32, buf[i + 12]) << 5)); + j += 8; + } + return p; +} + +/// Convert coefficient from centered representation to non-negative. +/// Transforms value from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁). +fn centeredToPositive(val: u32, comptime gamma1: u32) u32 { + var result = gamma1 -% val; + result +%= (signMask(u32, result) & Q); + return result; +} + +/// Pack polynomial with coefficients in (-gamma1, gamma1] into bytes. +/// For gamma1_bits=17: packs 18 bits per coefficient (4 coefficients into 9 bytes) +/// For gamma1_bits=19: packs 20 bits per coefficient (2 coefficients into 5 bytes) +/// Assumes coefficients are normalized. +fn polyPackLeGamma1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void { + const gamma1: u32 = @as(u32, 1) << gamma1_bits; + + if (gamma1_bits == 17) { + // Pack 4 coefficients into 9 bytes (18 bits each) + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 9) { + // Convert from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁) + const p0 = centeredToPositive(p.cs[j], gamma1); + const p1 = centeredToPositive(p.cs[j + 1], gamma1); + const p2 = centeredToPositive(p.cs[j + 2], gamma1); + const p3 = centeredToPositive(p.cs[j + 3], gamma1); + + buf[i] = @truncate(p0); + buf[i + 1] = @truncate(p0 >> 8); + buf[i + 2] = @truncate((p0 >> 16) | (p1 << 2)); + buf[i + 3] = @truncate(p1 >> 6); + buf[i + 4] = @truncate((p1 >> 14) | (p2 << 4)); + buf[i + 5] = @truncate(p2 >> 4); + buf[i + 6] = @truncate((p2 >> 12) | (p3 << 6)); + buf[i + 7] = @truncate(p3 >> 2); + buf[i + 8] = @truncate(p3 >> 10); + + j += 4; + } + } else if (gamma1_bits == 19) { + // Pack 2 coefficients into 5 bytes (20 bits each) + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 5) { + const p0 = centeredToPositive(p.cs[j], gamma1); + const p1 = centeredToPositive(p.cs[j + 1], gamma1); + + buf[i] = @truncate(p0); + buf[i + 1] = @truncate(p0 >> 8); + buf[i + 2] = @truncate((p0 >> 16) | (p1 << 4)); + buf[i + 3] = @truncate(p1 >> 4); + buf[i + 4] = @truncate(p1 >> 12); + + j += 2; + } + } else { + @compileError("gamma1_bits must be 17 or 19"); + } +} + +/// Unpack polynomial with coefficients in (-gamma1, gamma1] from bytes. +/// Output coefficients will be normalized. 
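+/// Inverse of polyPackLeGamma1; centeredToPositive is an involution modulo Q,
+/// so the same helper converts in both directions.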
+fn polyUnpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Poly { + const gamma1: u32 = @as(u32, 1) << gamma1_bits; + var p = Poly.zero; + + if (gamma1_bits == 17) { + // Unpack 4 coefficients from 9 bytes (18 bits each) + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 9) { + var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0x3) << 16); + var p1 = (@as(u32, buf[i + 2]) >> 2) | (@as(u32, buf[i + 3]) << 6) | ((@as(u32, buf[i + 4]) & 0xf) << 14); + var p2 = (@as(u32, buf[i + 4]) >> 4) | (@as(u32, buf[i + 5]) << 4) | ((@as(u32, buf[i + 6]) & 0x3f) << 12); + var p3 = (@as(u32, buf[i + 6]) >> 6) | (@as(u32, buf[i + 7]) << 2) | (@as(u32, buf[i + 8]) << 10); + + // Convert from [0, 2γ₁) to (-γ₁, γ₁] + p0 = centeredToPositive(p0, gamma1); + p1 = centeredToPositive(p1, gamma1); + p2 = centeredToPositive(p2, gamma1); + p3 = centeredToPositive(p3, gamma1); + + p.cs[j] = p0; + p.cs[j + 1] = p1; + p.cs[j + 2] = p2; + p.cs[j + 3] = p3; + + j += 4; + } + } else if (gamma1_bits == 19) { + // Unpack 2 coefficients from 5 bytes (20 bits each) + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 5) { + var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0xf) << 16); + var p1 = (@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4) | (@as(u32, buf[i + 4]) << 12); + + p0 = centeredToPositive(p0, gamma1); + p1 = centeredToPositive(p1, gamma1); + + p.cs[j] = p0; + p.cs[j + 1] = p1; + + j += 2; + } + } else { + @compileError("gamma1_bits must be 17 or 19"); + } + + return p; +} + +/// Pack W1 polynomial for verification. +/// For gamma1_bits=17: packs 6 bits per coefficient (4 coefficients into 3 bytes) +/// For gamma1_bits=19: packs 4 bits per coefficient (2 coefficients into 1 byte) +/// Assumes coefficients are normalized. +fn polyPackW1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void { + if (gamma1_bits == 17) { + // Pack 4 coefficients into 3 bytes (6 bits each) + var j: usize = 0; + var i: usize = 0; + while (i < buf.len) : (i += 3) { + buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 6)); + buf[i + 1] = @truncate((p.cs[j + 1] >> 2) | (p.cs[j + 2] << 4)); + buf[i + 2] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 2)); + j += 4; + } + } else if (gamma1_bits == 19) { + // Pack 2 coefficients into 1 byte (4 bits each) - equivalent to packLe16 + var j: usize = 0; + for (0..buf.len) |i| { + buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 4)); + j += 2; + } + } else { + @compileError("gamma1_bits must be 17 or 19"); + } +} + +fn polyDeriveUniform(seed: *const [32]u8, nonce: u16) Poly { + var domain_sep: [2]u8 = undefined; + domain_sep[0] = @truncate(nonce); + domain_sep[1] = @truncate(nonce >> 8); + + return sampleUniformRejection( + Poly, + Q, + 23, + N, + seed, + &domain_sep, + ); +} + +/// Sample p uniformly with coefficients of norm less than or equal to η, +/// using the given seed and nonce with SHAKE-256. +/// The polynomial will not be normalized, but will have coefficients in [q-η, q+η]. 
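+/// For eta = 2, an accepted nibble t <= 14 is reduced mod 5 with the
+/// multiply-shift (205 * t) >> 10, which equals t / 5 (rounded down) for all t <= 15.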
+/// FIPS 204: ExpandS (Algorithm 27) +fn expandS(comptime eta: u8, seed: *const [64]u8, nonce: u16) Poly { + comptime { + if (eta != 2 and eta != 4) { + @compileError("eta must be 2 or 4"); + } + } + + var p = Poly.zero; + var i: usize = 0; + + var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes + + // Prepare input: seed || nonce (little-endian u16) + var input: [66]u8 = undefined; + @memcpy(input[0..64], seed); + input[64] = @truncate(nonce); + input[65] = @truncate(nonce >> 8); + + var h = sha3.Shake256.init(.{}); + h.update(&input); + + while (i < N) { + h.squeeze(&buf); + + // Process buffer: extract two samples per byte (4-bit nibbles) + var j: usize = 0; + while (j < buf.len and i < N) : (j += 1) { + var t1 = @as(u32, buf[j]) & 15; + var t2 = @as(u32, buf[j]) >> 4; + + if (eta == 2) { + // For eta=2: reject if t > 14, then reduce mod 5 + if (t1 <= 14) { + t1 -%= ((205 * t1) >> 10) * 5; // reduce mod 5 + p.cs[i] = Q + eta - t1; + i += 1; + } + if (t2 <= 14 and i < N) { + t2 -%= ((205 * t2) >> 10) * 5; // reduce mod 5 + p.cs[i] = Q + eta - t2; + i += 1; + } + } else if (eta == 4) { + // For eta=4: accept if t <= 2*eta = 8 + if (t1 <= 2 * eta) { + p.cs[i] = Q + eta - t1; + i += 1; + } + if (t2 <= 2 * eta and i < N) { + p.cs[i] = Q + eta - t2; + i += 1; + } + } + } + } + + return p; +} + +/// Sample p uniformly with τ non-zero coefficients in {Q-1, 1} using SHAKE-256. +/// This creates a "ball" polynomial with exactly tau non-zero ±1 coefficients. +/// The polynomial will be normalized with coefficients in {0, 1, Q-1}. +/// FIPS 204: SampleInBall (Algorithm 18) +fn sampleInBall(comptime tau: u16, seed: []const u8) Poly { + var p = Poly.zero; + + var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes + + var h = sha3.Shake256.init(.{}); + h.update(seed); + h.squeeze(&buf); + + // Extract signs from first 8 bytes + var signs: u64 = 0; + for (0..8) |j| { + signs |= @as(u64, buf[j]) << @intCast(j * 8); + } + var buf_off: usize = 8; + + // Generate tau non-zero coefficients using Fisher-Yates shuffle + // Start with N-tau zeros, then add tau ±1 values + var i: u16 = N - tau; + while (i < N) : (i += 1) { + var b: u16 = undefined; + + // Find location using rejection sampling + while (true) { + if (buf_off >= buf.len) { + h.squeeze(&buf); + buf_off = 0; + } + + b = buf[buf_off]; + buf_off += 1; + + if (b <= i) { + break; + } + } + + // Shuffle: move existing value to position i + p.cs[i] = p.cs[b]; + + // Set position b to ±1 based on sign bit + p.cs[b] = 1; + const sign_bit: u1 = @truncate(signs); + const mask = bitMask(u32, sign_bit); + p.cs[b] ^= mask & (1 | (Q - 1)); + signs >>= 1; + } + + return p; +} + +/// Sample a polynomial with coefficients uniformly distributed in (-gamma1, gamma1] +/// Used for sampling the masking vector y during signing +/// FIPS 204: ExpandMask (Algorithm 28) +fn expandMask(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Poly { + const packed_size = ((gamma1_bits + 1) * N) / 8; + var buf: [packed_size]u8 = undefined; + + // Construct IV: seed || nonce (little-endian) + var iv: [66]u8 = undefined; + @memcpy(iv[0..64], seed); + iv[64] = @truncate(nonce & 0xFF); + iv[65] = @truncate(nonce >> 8); + + var h = sha3.Shake256.init(.{}); + h.update(&iv); + h.squeeze(&buf); + + // Unpack the polynomial + return polyUnpackLeGamma1(gamma1_bits, &buf); +} + +fn MLDSAImpl(comptime p: Params) type { + return struct { + pub const params = p; + pub const name = p.name; + pub const gamma1: u32 = @as(u32, 1) 
<< p.gamma1_bits; + pub const beta: u32 = p.tau * p.eta; + pub const alpha: u32 = 2 * p.gamma2; + + const Self = @This(); + const PolyVecL = PolyVec(p.l); + const PolyVecK = PolyVec(p.k); + const MatKxL = Mat(p.k, p.l); + + /// Length of the seed used for deterministic key generation (32 bytes). + pub const seed_length: usize = 32; + + /// Length (in bytes) of optional random bytes, for non-deterministic signatures. + pub const noise_length = 32; + + /// Size of an encoded public key in bytes. + pub const public_key_bytes: usize = 32 + polyT1PackedSize() * p.k; + + /// Size of an encoded secret key in bytes. + pub const private_key_bytes: usize = 32 + 32 + p.tr_size + + polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k; + + /// Size of an encoded signature in bytes. + pub const signature_bytes: usize = p.ctilde_size + + polyLeGamma1PackedSize() * p.l + p.omega + p.k; + + // Packed sizes for different polynomial representations + fn polyLeqEtaPackedSize() usize { + // For eta=2: 3 bits per coefficient (values in [0,4]) + // For eta=4: 4 bits per coefficient (values in [0,8]) + const double_eta_bits = if (p.eta == 2) 3 else 4; + return (N * double_eta_bits) / 8; + } + + fn polyLeGamma1PackedSize() usize { + return ((p.gamma1_bits + 1) * N) / 8; + } + + fn polyT1PackedSize() usize { + return (N * (Q_BITS - D)) / 8; + } + + fn polyT0PackedSize() usize { + return (N * D) / 8; + } + + fn polyW1PackedSize() usize { + return (N * (Q_BITS - p.gamma1_bits)) / 8; + } + + /// Helper function to compute CRH (Collision Resistant Hash) using SHAKE-256. + /// This consolidates the repeated pattern of init-update-squeeze for hash operations. + fn crh(comptime outsize: usize, inputs: anytype) [outsize]u8 { + var h = sha3.Shake256.init(.{}); + inline for (inputs) |input| { + h.update(input); + } + var out: [outsize]u8 = undefined; + h.squeeze(&out); + return out; + } + + /// Helper function to compute t = As1 + s2. + /// This is used during key generation and public key reconstruction. 
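+        /// Note: s1_hat must already be in the NTT domain, while s2 is in the
+        /// regular domain; the result is normalized to [0, q).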
+ fn computeT(A: MatKxL, s1_hat: PolyVecL, s2: PolyVecK) PolyVecK { + const t = A.mulVec(s1_hat).add(s2); + return t.normalize(); + } + + /// ML-DSA public key + pub const PublicKey = struct { + /// Size of the encoded public key in bytes + pub const encoded_length: usize = 32 + polyT1PackedSize() * p.k; + + rho: [32]u8, // Seed for matrix A + t1: PolyVecK, // High bits of t = As1 + s2 + + // Cached values + t1_packed: [polyT1PackedSize() * p.k]u8, + A: MatKxL, + tr: [p.tr_size]u8, // CRH(rho || t1) + + /// Encode public key to bytes + pub fn toBytes(self: PublicKey) [encoded_length]u8 { + var out: [encoded_length]u8 = undefined; + @memcpy(out[0..32], &self.rho); + @memcpy(out[32..], &self.t1_packed); + return out; + } + + /// Decode public key from bytes + pub fn fromBytes(bytes: [encoded_length]u8) !PublicKey { + var pk: PublicKey = undefined; + @memcpy(&pk.rho, bytes[0..32]); + @memcpy(&pk.t1_packed, bytes[32..]); + + pk.t1 = PolyVecK.unpackT1(pk.t1_packed[0..]); + pk.A = MatKxL.derive(&pk.rho); + pk.tr = crh(p.tr_size, .{&bytes}); + + return pk; + } + }; + + /// ML-DSA secret key + pub const SecretKey = struct { + /// Size of the encoded secret key in bytes + pub const encoded_length: usize = 32 + 32 + p.tr_size + + polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k; + + rho: [32]u8, // Seed for matrix A + key: [32]u8, // Seed for signature generation randomness + tr: [p.tr_size]u8, // CRH(rho || t1) + s1: PolyVecL, // Secret vector 1 + s2: PolyVecK, // Secret vector 2 + t0: PolyVecK, // Low bits of t = As1 + s2 + + // Cached values (in NTT domain) + A: MatKxL, + s1_hat: PolyVecL, + s2_hat: PolyVecK, + t0_hat: PolyVecK, + + /// Encode secret key to bytes + pub fn toBytes(self: SecretKey) [encoded_length]u8 { + var out: [encoded_length]u8 = undefined; + var offset: usize = 0; + + @memcpy(out[offset .. offset + 32], &self.rho); + offset += 32; + + @memcpy(out[offset .. offset + 32], &self.key); + offset += 32; + + @memcpy(out[offset .. offset + p.tr_size], &self.tr); + offset += p.tr_size; + + if (p.eta == 2) { + self.s1.packLeqEta(2, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]); + } else { + self.s1.packLeqEta(4, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]); + } + offset += p.l * polyLeqEtaPackedSize(); + + if (p.eta == 2) { + self.s2.packLeqEta(2, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]); + } else { + self.s2.packLeqEta(4, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]); + } + offset += p.k * polyLeqEtaPackedSize(); + + self.t0.packT0(out[offset..][0 .. p.k * polyT0PackedSize()]); + offset += p.k * polyT0PackedSize(); + + return out; + } + + /// Decode secret key from bytes + pub fn fromBytes(bytes: [encoded_length]u8) !SecretKey { + var sk: SecretKey = undefined; + var offset: usize = 0; + + @memcpy(&sk.rho, bytes[offset .. offset + 32]); + offset += 32; + + @memcpy(&sk.key, bytes[offset .. offset + 32]); + offset += 32; + + @memcpy(&sk.tr, bytes[offset .. offset + p.tr_size]); + offset += p.tr_size; + + sk.s1 = if (p.eta == 2) + PolyVecL.unpackLeqEta(2, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]) + else + PolyVecL.unpackLeqEta(4, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]); + offset += p.l * polyLeqEtaPackedSize(); + + sk.s2 = if (p.eta == 2) + PolyVecK.unpackLeqEta(2, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()]) + else + PolyVecK.unpackLeqEta(4, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()]); + offset += p.k * polyLeqEtaPackedSize(); + + sk.t0 = PolyVecK.unpackT0(bytes[offset..][0 .. 
p.k * polyT0PackedSize()]); + offset += p.k * polyT0PackedSize(); + + // Compute cached NTT values for efficient signing + sk.A = MatKxL.derive(&sk.rho); + sk.s1_hat = sk.s1.ntt(); + sk.s2_hat = sk.s2.ntt(); + sk.t0_hat = sk.t0.ntt(); + + return sk; + } + + /// Compute the public key from this private key + pub fn public(self: *const SecretKey) PublicKey { + var pk: PublicKey = undefined; + pk.rho = self.rho; + pk.A = self.A; + pk.tr = self.tr; + + // Reconstruct t = As1 + s2, then extract high bits t1 + // Using power2Round: t = t1 * 2^D + t0 + const t = computeT(self.A, self.s1_hat, self.s2); + + var t0_unused: PolyVecK = undefined; + pk.t1 = t.power2Round(&t0_unused); + pk.t1.packT1(&pk.t1_packed); + + return pk; + } + + /// Create a Signer for incrementally signing a message. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + pub fn signer(self: *const SecretKey, noise: ?[noise_length]u8) !Signer { + return self.signerWithContext(noise, ""); + } + + /// Create a Signer for incrementally signing a message with context. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + /// The context parameter is an optional context string (max 255 bytes). + pub fn signerWithContext(self: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { + return Signer.init(self, noise, context); + } + }; + + /// Generate a new key pair from a seed (deterministic) + pub fn newKeyFromSeed(seed: *const [seed_length]u8) struct { pk: PublicKey, sk: SecretKey } { + var sk: SecretKey = undefined; + var pk: PublicKey = undefined; + + // NIST mode: expand seed || k || l using SHAKE-256 to get 128-byte expanded seed + const e_seed = crh(128, .{ seed, &[_]u8{ p.k, p.l } }); + + @memcpy(&pk.rho, e_seed[0..32]); + const s_seed = e_seed[32..96]; + @memcpy(&sk.key, e_seed[96..128]); + @memcpy(&sk.rho, &pk.rho); + + sk.A = MatKxL.derive(&pk.rho); + pk.A = sk.A; + + const s_seed_array: *const [64]u8 = s_seed[0..64]; + for (0..p.l) |i| { + sk.s1.ps[i] = expandS(p.eta, s_seed_array, @intCast(i)); + } + + for (0..p.k) |i| { + sk.s2.ps[i] = expandS(p.eta, s_seed_array, @intCast(p.l + i)); + } + + sk.s1_hat = sk.s1.ntt(); + sk.s2_hat = sk.s2.ntt(); + + const t = computeT(sk.A, sk.s1_hat, sk.s2); + + pk.t1 = t.power2Round(&sk.t0); + sk.t0_hat = sk.t0.ntt(); + pk.t1.packT1(&pk.t1_packed); + + // tr = H(pk) = H(rho || t1) + const pk_bytes = pk.toBytes(); + const tr = crh(p.tr_size, .{&pk_bytes}); + sk.tr = tr; + pk.tr = tr; + + return .{ .pk = pk, .sk = sk }; + } + + /// ML-DSA signature + pub const Signature = struct { + /// Size of the encoded signature in bytes + pub const encoded_length: usize = p.ctilde_size + + polyLeGamma1PackedSize() * p.l + p.omega + p.k; + + c_tilde: [p.ctilde_size]u8, // Challenge hash + z: PolyVecL, // Response vector + hint: PolyVecK, // Hint vector + + /// Encode signature to bytes + pub fn toBytes(self: Signature) [encoded_length]u8 { + var out: [encoded_length]u8 = undefined; + var offset: usize = 0; + + @memcpy(out[offset .. offset + p.ctilde_size], &self.c_tilde); + offset += p.ctilde_size; + + self.z.packLeGamma1(p.gamma1_bits, out[offset .. 
offset + polyLeGamma1PackedSize() * p.l]); + offset += polyLeGamma1PackedSize() * p.l; + + _ = self.hint.packHint(p.omega, out[offset..]); + + return out; + } + + /// Decode signature from bytes + pub fn fromBytes(bytes: [encoded_length]u8) EncodingError!Signature { + var sig: Signature = undefined; + var offset: usize = 0; + + @memcpy(&sig.c_tilde, bytes[offset .. offset + p.ctilde_size]); + offset += p.ctilde_size; + + sig.z = PolyVecL.unpackLeGamma1(p.gamma1_bits, bytes[offset .. offset + polyLeGamma1PackedSize() * p.l]); + offset += polyLeGamma1PackedSize() * p.l; + + // Validate ||z||_inf < gamma1 - beta per FIPS 204 + if (sig.z.exceeds(gamma1 - beta)) { + return error.InvalidEncoding; + } + + sig.hint = PolyVecK.unpackHint(p.omega, bytes[offset..]) orelse return error.InvalidEncoding; + + return sig; + } + + pub const VerifyError = Verifier.InitError || Verifier.VerifyError; + + /// Verify this signature against a message and public key. + /// Returns an error if the signature is invalid. + pub fn verify( + sig: Signature, + msg: []const u8, + public_key: PublicKey, + ) VerifyError!void { + return sig.verifyWithContext(msg, public_key, ""); + } + + /// Verify this signature against a message and public key with context. + /// Returns an error if the signature is invalid. + /// The context parameter is an optional context string (max 255 bytes). + pub fn verifyWithContext( + sig: Signature, + msg: []const u8, + public_key: PublicKey, + context: []const u8, + ) VerifyError!void { + if (context.len > 255) { + return error.SignatureVerificationFailed; + } + + var h = sha3.Shake256.init(.{}); + h.update(&public_key.tr); + h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA + h.update(&[_]u8{@intCast(context.len)}); + if (context.len > 0) { + h.update(context); + } + h.update(msg); + var mu: [64]u8 = undefined; + h.squeeze(&mu); + + const z_hat = sig.z.ntt(); + const Az = public_key.A.mulVecHat(z_hat); + + // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing) + var Az2dct1 = public_key.t1.mulBy2toD(); + Az2dct1 = Az2dct1.ntt(); + const c_poly = sampleInBall(p.tau, &sig.c_tilde); + const c_hat = c_poly.ntt(); + for (0..p.k) |i| { + Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat); + } + Az2dct1 = Az.sub(Az2dct1); + Az2dct1 = Az2dct1.reduceLe2Q(); + Az2dct1 = Az2dct1.invNTT(); + Az2dct1 = Az2dct1.normalizeAssumingLe2Q(); + + // Apply hints to recover high bits w1' + var w1_prime = Az2dct1.useHint(sig.hint, p.gamma2); + var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined; + w1_prime.packW1(p.gamma1_bits, &w1_packed); + + const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed }); + + if (!mem.eql(u8, &c_prime, &sig.c_tilde)) { + return error.SignatureVerificationFailed; + } + } + + /// Create a Verifier for incrementally verifying a signature. + pub fn verifier(self: Signature, public_key: PublicKey) !Verifier { + return self.verifierWithContext(public_key, ""); + } + + /// Create a Verifier for incrementally verifying a signature with context. + /// The context parameter is an optional context string (max 255 bytes). + pub fn verifierWithContext(self: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier { + return Verifier.init(self, public_key, context); + } + }; + + /// A Signer is used to incrementally compute a signature over a streamed message. + /// It can be obtained from a `SecretKey` or `KeyPair`, using the `signer()` function. 
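+    ///
+    /// Example (a sketch; assumes an existing `kp: KeyPair`):
+    ///
+    /// ```zig
+    /// var st = try kp.signer(null);
+    /// st.update("streamed ");
+    /// st.update("message");
+    /// const sig = st.finalize();
+    /// ```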
+ pub const Signer = struct { + h: sha3.Shake256, // For computing μ = CRH(tr || msg) + secret_key: *const SecretKey, + rnd: [32]u8, + + /// Initialize a new Signer. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + /// The context parameter is an optional context string (max 255 bytes). + pub fn init(secret_key: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { + if (context.len > 255) { + return error.ContextTooLong; + } + + var h = sha3.Shake256.init(.{}); + h.update(&secret_key.tr); + h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA + h.update(&[_]u8{@intCast(context.len)}); + if (context.len > 0) { + h.update(context); + } + + return Signer{ + .h = h, + .secret_key = secret_key, + .rnd = noise orelse .{0} ** 32, + }; + } + + /// Add new data to the message being signed. + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + /// Compute a signature over the entire message. + pub fn finalize(self: *Signer) Signature { + var mu: [64]u8 = undefined; + self.h.squeeze(&mu); + + const rho_prime = crh(64, .{ &self.secret_key.key, &self.rnd, &mu }); + + var sig: Signature = undefined; + var y_nonce: u16 = 0; + + // Rejection sampling loop (FIPS 204 Algorithm 2, steps 5-16) + var attempt: u32 = 0; + while (true) { + attempt += 1; + if (attempt >= 576) { // (6/7)⁵⁷⁶ < 2⁻¹²⁸ + @branchHint(.unlikely); + unreachable; + } + + const y = PolyVecL.deriveUniformLeGamma1(p.gamma1_bits, &rho_prime, y_nonce); + y_nonce += @intCast(p.l); + + const y_hat = y.ntt(); + var w = self.secret_key.A.mulVec(y_hat); + + w = w.normalize(); + var w0: PolyVecK = undefined; + const w1 = w.decomposeVec(p.gamma2, &w0); + var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined; + w1.packW1(p.gamma1_bits, &w1_packed); + + sig.c_tilde = crh(p.ctilde_size, .{ &mu, &w1_packed }); + + const c_poly = sampleInBall(p.tau, &sig.c_tilde); + const c_hat = c_poly.ntt(); + + // Rejection check: ensure masking is effective + var w0mcs2: PolyVecK = undefined; + for (0..p.k) |i| { + w0mcs2.ps[i] = c_hat.mulHat(self.secret_key.s2_hat.ps[i]); + w0mcs2.ps[i] = w0mcs2.ps[i].invNTT(); + } + w0mcs2 = w0.sub(w0mcs2); + w0mcs2 = w0mcs2.normalize(); + + if (w0mcs2.exceeds(p.gamma2 - beta)) { + continue; + } + + // Compute response z = y + c·s1 + for (0..p.l) |i| { + sig.z.ps[i] = c_hat.mulHat(self.secret_key.s1_hat.ps[i]); + sig.z.ps[i] = sig.z.ps[i].invNTT(); + } + sig.z = sig.z.add(y); + sig.z = sig.z.normalize(); + + if (sig.z.exceeds(gamma1 - beta)) { + continue; + } + + var ct0: PolyVecK = undefined; + for (0..p.k) |i| { + ct0.ps[i] = c_hat.mulHat(self.secret_key.t0_hat.ps[i]); + ct0.ps[i] = ct0.ps[i].invNTT(); + } + ct0 = ct0.reduceLe2Q(); + ct0 = ct0.normalize(); + + if (ct0.exceeds(p.gamma2)) { + continue; + } + + // Generate hints for verification + var w0mcs2pct0 = w0mcs2.add(ct0); + w0mcs2pct0 = w0mcs2pct0.reduceLe2Q(); + w0mcs2pct0 = w0mcs2pct0.normalizeAssumingLe2Q(); + const hint_result = PolyVecK.makeHintVec(w0mcs2pct0, w1, p.gamma2); + if (hint_result.pop > p.omega) { + continue; + } + sig.hint = hint_result.hint; + + return sig; + } + } + }; + + /// A Verifier is used to incrementally verify a signature over a streamed message. + /// It can be obtained from a `Signature`, using the `verifier()` function. 
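+    ///
+    /// Example (a sketch; `sig` and `kp` are as in the Signer example above):
+    ///
+    /// ```zig
+    /// var vt = try sig.verifier(kp.public_key);
+    /// vt.update("streamed ");
+    /// vt.update("message");
+    /// try vt.verify();
+    /// ```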
+ pub const Verifier = struct {
+ h: sha3.Shake256, // For computing μ = CRH(tr || 0 || ctx_len || ctx || msg)
+ signature: Signature,
+ public_key: PublicKey,
+
+ pub const InitError = ContextTooLongError;
+ pub const VerifyError = SignatureVerificationError;
+
+ /// Initialize a new Verifier.
+ /// The context parameter is an optional context string (max 255 bytes).
+ pub fn init(signature: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier {
+ if (context.len > 255) {
+ return error.ContextTooLong;
+ }
+
+ var h = sha3.Shake256.init(.{});
+ h.update(&public_key.tr);
+ h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
+ h.update(&[_]u8{@intCast(context.len)}); // Context length
+ if (context.len > 0) {
+ h.update(context);
+ }
+
+ return Verifier{
+ .h = h,
+ .signature = signature,
+ .public_key = public_key,
+ };
+ }
+
+ /// Add new content to the message to be verified.
+ pub fn update(self: *Verifier, data: []const u8) void {
+ self.h.update(data);
+ }
+
+ /// Verify that the signature is valid for the entire message.
+ pub fn verify(self: *Verifier) SignatureVerificationError!void {
+ var mu: [64]u8 = undefined;
+ self.h.squeeze(&mu);
+
+ const z_hat = self.signature.z.ntt();
+ const Az = self.public_key.A.mulVecHat(z_hat);
+
+ // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing)
+ var Az2dct1 = self.public_key.t1.mulBy2toD();
+ Az2dct1 = Az2dct1.ntt();
+ const c_poly = sampleInBall(p.tau, &self.signature.c_tilde);
+ const c_hat = c_poly.ntt();
+ for (0..p.k) |i| {
+ Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat);
+ }
+ Az2dct1 = Az.sub(Az2dct1);
+ Az2dct1 = Az2dct1.reduceLe2Q();
+ Az2dct1 = Az2dct1.invNTT();
+ Az2dct1 = Az2dct1.normalizeAssumingLe2Q();
+
+ // Apply hints to recover high bits w1'
+ const w1_prime = Az2dct1.useHint(self.signature.hint, p.gamma2);
+ var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
+ w1_prime.packW1(p.gamma1_bits, &w1_packed);
+
+ const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed });
+
+ if (!mem.eql(u8, &c_prime, &self.signature.c_tilde)) {
+ return error.SignatureVerificationFailed;
+ }
+ }
+ };
+
+ /// A key pair consisting of a secret key and its corresponding public key.
+ pub const KeyPair = struct {
+ /// Length (in bytes) of a seed required to create a key pair.
+ pub const seed_length = Self.seed_length;
+
+ /// The public key component.
+ public_key: PublicKey,
+
+ /// The secret key component.
+ secret_key: SecretKey,
+
+ /// Generate a new random key pair.
+ /// This uses the system's cryptographically secure random number generator.
+ ///
+ /// `crypto.random.bytes` must be supported by the target.
+ pub fn generate() KeyPair {
+ var seed: [Self.seed_length]u8 = undefined;
+ crypto.random.bytes(&seed);
+ return generateDeterministic(seed) catch unreachable;
+ }
+
+ /// Generate a key pair deterministically from a seed.
+ /// Use for testing or when reproducibility is required.
+ /// The seed should be generated using a cryptographically secure random source.
+ pub fn generateDeterministic(seed: [seed_length]u8) !KeyPair {
+ const keys = newKeyFromSeed(&seed);
+ return .{
+ .public_key = keys.pk,
+ .secret_key = keys.sk,
+ };
+ }
+
+ /// Derive the public key from an existing secret key.
+ /// This recomputes the public key components from the secret key.
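+ /// Example (illustrative, where `sk` is an existing `SecretKey`):
+ ///
+ /// ```zig
+ /// const kp = try KeyPair.fromSecretKey(sk);
+ /// ```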
+ pub fn fromSecretKey(sk: SecretKey) !KeyPair { + var pk: PublicKey = undefined; + pk.rho = sk.rho; + pk.tr = sk.tr; + pk.A = sk.A; + + const t = computeT(sk.A, sk.s1_hat, sk.s2); + + var t0: PolyVecK = undefined; + pk.t1 = t.power2Round(&t0); + pk.t1.packT1(&pk.t1_packed); + + return .{ + .public_key = pk, + .secret_key = sk, + }; + } + + /// Create a Signer for incrementally signing a message. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + pub fn signer(self: *const KeyPair, noise: ?[noise_length]u8) !Signer { + return self.secret_key.signer(noise); + } + + /// Create a Signer for incrementally signing a message with context. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + /// The context parameter is an optional context string (max 255 bytes). + pub fn signerWithContext(self: *const KeyPair, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { + return self.secret_key.signerWithContext(noise, context); + } + + /// Sign a message using this key pair. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + pub fn sign( + kp: KeyPair, + msg: []const u8, + noise: ?[noise_length]u8, + ) !Signature { + return kp.signWithContext(msg, noise, ""); + } + + /// Sign a message using this key pair with context. + /// The noise parameter can be null for deterministic signatures, + /// or provide randomness for hedged signatures (recommended for fault attack resistance). + /// The context parameter is an optional context string (max 255 bytes). 
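+ /// Example (the context string shown here is an arbitrary application choice):
+ ///
+ /// ```zig
+ /// const sig = try kp.signWithContext(msg, null, "my-protocol-v1");
+ /// try sig.verifyWithContext(msg, kp.public_key, "my-protocol-v1");
+ /// ```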
+ pub fn signWithContext( + kp: KeyPair, + msg: []const u8, + noise: ?[noise_length]u8, + context: []const u8, + ) ContextTooLongError!Signature { + var st = try kp.signerWithContext(noise, context); + st.update(msg); + return st.finalize(); + } + }; + }; +} + +test "modular arithmetic" { + // Test Montgomery reduction + const x: u64 = 12345678; + const y = montReduceLe2Q(x); + try testing.expect(y < 2 * Q); + + // Test modQ + try testing.expectEqual(@as(u32, 0), modQ(Q)); + try testing.expectEqual(@as(u32, 1), modQ(Q + 1)); +} + +test "polynomial operations" { + var p1 = Poly.zero; + p1.cs[0] = 1; + p1.cs[1] = 2; + + var p2 = Poly.zero; + p2.cs[0] = 3; + p2.cs[1] = 4; + + const p3 = p1.add(p2); + try testing.expectEqual(@as(u32, 4), p3.cs[0]); + try testing.expectEqual(@as(u32, 6), p3.cs[1]); +} + +test "NTT and inverse NTT" { + // Create a test polynomial in REGULAR FORM (not Montgomery) + var p = Poly.zero; + for (0..N) |i| { + p.cs[i] = @intCast(i % Q); + } + + // Apply NTT then inverse NTT + // According to Dilithium spec: NTT followed by invNTT multiplies by R + // So result will be p * R (i.e., p in Montgomery form) + var p_ntt = p.ntt(); + + // Reduce before invNTT (as Go test does) + p_ntt = p_ntt.reduceLe2Q(); + + const p_restored = p_ntt.invNTT(); + + // Reduce and normalize + const p_reduced = p_restored.reduceLe2Q(); + const p_norm = p_reduced.normalize(); + + // Check if we get p * R (which equals toMont(p)) + for (0..N) |i| { + const original: u32 = @intCast(i % Q); + const expected = toMont(original); + const expected_norm = modQ(expected); + try testing.expectEqual(expected_norm, p_norm.cs[i]); + } +} + +test "parameter set instantiation" { + // Just verify we can instantiate all three parameter sets + const ml44 = MLDSA44; + const ml65 = MLDSA65; + const ml87 = MLDSA87; + + try testing.expectEqualStrings("ML-DSA-44", ml44.name); + try testing.expectEqualStrings("ML-DSA-65", ml65.name); + try testing.expectEqualStrings("ML-DSA-87", ml87.name); +} + +test "compare zetas with Go implementation" { + // First 16 zetas from Go implementation (in Montgomery form) + const go_zetas = [16]u32{ + 4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, + 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, + 3111497, 2680103, + }; + + // Compare our computed zetas with Go's + for (0..16) |i| { + try testing.expectEqual(go_zetas[i], zetas[i]); + } +} + +test "NTT with simple polynomial" { + // Test with a very simple polynomial: just one coefficient set to 1 in regular form + var p = Poly.zero; + p.cs[0] = 1; + + var p_ntt = p.ntt(); + + // Reduce before invNTT (as Go test does) + p_ntt = p_ntt.reduceLe2Q(); + + const p_restored = p_ntt.invNTT(); + + // Result should be 1 * R = toMont(1) in Montgomery form + const p_reduced = p_restored.reduceLe2Q(); + const p_norm = p_reduced.normalize(); + + const expected = modQ(toMont(1)); + try testing.expectEqual(expected, p_norm.cs[0]); + + // All other coefficients should be 0 * R = 0 + for (1..N) |i| { + try testing.expectEqual(@as(u32, 0), p_norm.cs[i]); + } +} + +test "Montgomery reduction correctness" { + // Test that Montgomery reduction works correctly + // montReduceLe2Q(a * b * R) = a * b mod q (where a, b are in Montgomery form) + + const x: u32 = 12345; + const y: u32 = 67890; + + // Convert to Montgomery form + const x_mont = toMont(x); + const y_mont = toMont(y); + + // Multiply in Montgomery form + const product_mont = montReduceLe2Q(@as(u64, x_mont) * @as(u64, y_mont)); + + // Convert back from Montgomery form + 
const product = montReduceLe2Q(@as(u64, product_mont));
+
+ // Direct multiplication mod q
+ const expected = modQ(@as(u32, @intCast((@as(u64, x) * @as(u64, y)) % Q)));
+
+ try testing.expectEqual(expected, modQ(product));
+}
+
+test "compare inv_zetas with Go implementation" {
+ // First 16 inv_zetas from Go implementation
+ const go_inv_zetas = [16]u32{
+ 6403635, 846154, 6979993, 4442679, 1362209, 48306, 4460757,
+ 554416, 3545687, 6767575, 976891, 8196974, 2286327, 420899,
+ 2235985, 2939036,
+ };
+
+ // Compare our computed inv_zetas with Go's
+ for (0..16) |i| {
+ try testing.expectEqual(go_inv_zetas[i], inv_zetas[i]);
+ }
+}
+
+test "power2Round correctness" {
+ // Test that power2Round correctly splits values
+ // For all a in [0, Q), we should have a = a1*2^D + a0
+ // where -2^(D-1) < a0 <= 2^(D-1)
+
+ // Test a few specific values
+ const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345, 8380416 };
+
+ for (test_values) |a| {
+ if (a >= Q) continue;
+
+ const result = power2Round(a);
+ const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
+ const a1 = result.a1;
+
+ // Check reconstruction: a = a1*2^D + a0
+ const reconstructed = @as(i32, @bitCast(a1 << D)) + a0;
+ try testing.expectEqual(@as(i32, @bitCast(a)), reconstructed);
+
+ // Check a0 bounds: -2^(D-1) < a0 <= 2^(D-1)
+ const bound: i32 = 1 << (D - 1);
+ try testing.expect(a0 > -bound and a0 <= bound);
+ }
+}
+
+test "decompose correctness for gamma2 = 95232 (ML-DSA-44)" {
+ // Test decompose with gamma2 = 95232, as used by ML-DSA-44
+ const gamma2 = 95232;
+ const alpha = 2 * gamma2;
+
+ const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };
+
+ for (test_values) |a| {
+ if (a >= Q) continue;
+
+ const result = decompose(a, gamma2);
+ const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
+ const a1 = result.a1;
+
+ // Check reconstruction: a = a1*alpha + a0 (mod Q)
+ var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
+ reconstructed = @mod(reconstructed, @as(i64, Q));
+ try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);
+
+ // Check a0 bounds (approximately)
+ const bound: i32 = @intCast(alpha / 2);
+ try testing.expect(@abs(a0) <= bound);
+ }
+}
+
+test "decompose correctness for gamma2 = 261888 (ML-DSA-65/87)" {
+ // Test decompose with gamma2 = 261888, as used by ML-DSA-65 and ML-DSA-87
+ const gamma2 = 261888;
+ const alpha = 2 * gamma2;
+
+ const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };
+
+ for (test_values) |a| {
+ if (a >= Q) continue;
+
+ const result = decompose(a, gamma2);
+ const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
+ const a1 = result.a1;
+
+ // Check reconstruction: a = a1*alpha + a0 (mod Q)
+ var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
+ reconstructed = @mod(reconstructed, @as(i64, Q));
+ try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);
+
+ // Check a0 bounds (approximately)
+ const bound: i32 = @intCast(alpha / 2);
+ try testing.expect(@abs(a0) <= bound);
+ }
+}
+
+test "polyDeriveUniform deterministic" {
+ // Test that polyDeriveUniform produces deterministic results
+ const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31;
+ const nonce: u16 = 0;
+
+ const p1 = polyDeriveUniform(&seed, nonce);
+ const p2 = polyDeriveUniform(&seed, nonce);
+
+ // Should be identical
+ for (0..N) |i| {
+ try testing.expectEqual(p1.cs[i],
p2.cs[i]); + } + + // All coefficients should be in [0, Q) + for (0..N) |i| { + try testing.expect(p1.cs[i] < Q); + } +} + +test "polyDeriveUniform different nonces" { + // Test that different nonces produce different polynomials + const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31; + + const p1 = polyDeriveUniform(&seed, 0); + const p2 = polyDeriveUniform(&seed, 1); + + // Should be different + var different = false; + for (0..N) |i| { + if (p1.cs[i] != p2.cs[i]) { + different = true; + break; + } + } + try testing.expect(different); +} + +test "expandS with eta=2" { + // Test eta=2 sampling + const seed: [64]u8 = .{0x02} ++ .{0x00} ** 63; + const nonce: u16 = 0; + + const p = expandS(2, &seed, nonce); + + // All coefficients should be in [Q-eta, Q+eta] + // The function returns coefficients as Q + eta - t, where t is in [0, 2*eta] + // So coefficients are in [Q-eta, Q+eta] + for (0..N) |i| { + const c = p.cs[i]; + // Check that c is in [Q-2, Q+2] + try testing.expect(c >= Q - 2 and c <= Q + 2); + } +} + +test "expandS with eta=4" { + // Test eta=4 sampling + const seed: [64]u8 = .{0x03} ++ .{0x00} ** 63; + const nonce: u16 = 0; + + const p = expandS(4, &seed, nonce); + + // All coefficients should be in [Q-eta, Q+eta] + for (0..N) |i| { + const c = p.cs[i]; + // Check bounds (coefficients are around Q ± eta) + const diff = if (c >= Q) c - Q else Q - c; + try testing.expect(diff <= 4); + } +} + +test "sampleInBall has correct weight" { + // Test that ball polynomial has exactly tau non-zero coefficients + const tau = 39; // From ML-DSA-44 + const seed: [32]u8 = .{0x04} ++ .{0x00} ** 31; + + const p = sampleInBall(tau, &seed); + + // Count non-zero coefficients + var count: u32 = 0; + for (0..N) |i| { + if (p.cs[i] != 0) { + count += 1; + // Non-zero coefficients should be 1 or Q-1 + try testing.expect(p.cs[i] == 1 or p.cs[i] == Q - 1); + } + } + + try testing.expectEqual(tau, count); +} + +test "sampleInBall deterministic" { + // Test that ball sampling is deterministic + const tau = 49; // From ML-DSA-65 + const seed: [32]u8 = .{0x05} ++ .{0x00} ** 31; + + const p1 = sampleInBall(tau, &seed); + const p2 = sampleInBall(tau, &seed); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p1.cs[i], p2.cs[i]); + } +} + +test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=2" { + // Test packing and unpacking for eta=2 + const eta = 2; + + // Create a test polynomial with coefficients in [Q-eta, Q+eta] + var p = Poly.zero; + for (0..N) |i| { + // Use various values in range + const val = @as(u32, @intCast(i % 5)); // 0, 1, 2, 3, 4 + p.cs[i] = Q + eta - val; + } + + // Pack it + var buf: [96]u8 = undefined; // eta=2: 3 bits per coeff = 96 bytes + polyPackLeqEta(p, eta, &buf); + + // Unpack it + const p2 = polyUnpackLeqEta(eta, &buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=4" { + // Test packing and unpacking for eta=4 + const eta = 4; + + // Create a test polynomial with coefficients in [Q-eta, Q+eta] + var p = Poly.zero; + for (0..N) |i| { + // Use various values in range + const val = @as(u32, @intCast(i % 9)); // 0, 1, 2, ..., 8 + p.cs[i] = Q + eta - val; + } + + // Pack it + var buf: [128]u8 = undefined; // eta=4: 4 bits per coeff = 128 bytes + polyPackLeqEta(p, eta, &buf); + + // Unpack it + const p2 = polyUnpackLeqEta(eta, &buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackT1 / 
polyUnpackT1 roundtrip" { + // Create a test polynomial with coefficients < 1024 + var p = Poly.zero; + for (0..N) |i| { + p.cs[i] = @intCast(i % 1024); + } + + // Pack it + var buf: [320]u8 = undefined; // (256 * 10) / 8 = 320 bytes + polyPackT1(p, &buf); + + // Unpack it + const p2 = polyUnpackT1(&buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackT0 / polyUnpackT0 roundtrip" { + // Create a test polynomial with coefficients in (Q-2^12, Q+2^12] + // This is the range (-2^12, 2^12] represented as unsigned around Q + const bound = 1 << 12; // 2^(D-1) where D=13 + var p = Poly.zero; + for (0..N) |i| { + // Cycle through valid range for T0 + // Values should be Q + offset where offset is in (-bound, bound] + const cycle_val = @as(i32, @intCast(i % (2 * bound))); // 0 to 2*bound-1 + const offset = cycle_val - bound + 1; // (-bound+1) to bound + p.cs[i] = @as(u32, @intCast(@as(i32, Q) + offset)); + } + + // Pack it + var buf: [416]u8 = undefined; // (256 * 13) / 8 = 416 bytes + polyPackT0(p, &buf); + + // Unpack it + const p2 = polyUnpackT0(&buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=17" { + const gamma1_bits = 17; + const gamma1: u32 = @as(u32, 1) << gamma1_bits; + + // Create a test polynomial with coefficients in (-gamma1, gamma1] + // Normalized: [0, gamma1] ∪ (Q-gamma1, Q) + var p = Poly.zero; + for (0..N) |i| { + if (i % 2 == 0) { + // Positive values: [0, gamma1] + p.cs[i] = @intCast((i / 2) % (gamma1 + 1)); + } else { + // Negative values: (Q-gamma1, Q) + const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1); + p.cs[i] = Q - neg_val; + } + } + + // Pack it + var buf: [576]u8 = undefined; // (256 * 18) / 8 = 576 bytes + polyPackLeGamma1(p, gamma1_bits, &buf); + + // Unpack it + const p2 = polyUnpackLeGamma1(gamma1_bits, &buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=19" { + const gamma1_bits = 19; + const gamma1: u32 = @as(u32, 1) << gamma1_bits; + + // Create a test polynomial with coefficients in (-gamma1, gamma1] + var p = Poly.zero; + for (0..N) |i| { + if (i % 2 == 0) { + // Positive values: [0, gamma1] + p.cs[i] = @intCast((i / 2) % (gamma1 + 1)); + } else { + // Negative values: (Q-gamma1, Q) + const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1); + p.cs[i] = Q - neg_val; + } + } + + // Pack it + var buf: [640]u8 = undefined; // (256 * 20) / 8 = 640 bytes + polyPackLeGamma1(p, gamma1_bits, &buf); + + // Unpack it + const p2 = polyUnpackLeGamma1(gamma1_bits, &buf); + + // Should be identical + for (0..N) |i| { + try testing.expectEqual(p.cs[i], p2.cs[i]); + } +} + +test "polyPackW1 for gamma1_bits=17" { + const gamma1_bits = 17; + + // Create a test polynomial with small coefficients (w1 values < 64) + var p = Poly.zero; + for (0..N) |i| { + p.cs[i] = @intCast(i % 64); // 6-bit values + } + + // Pack it + var buf: [192]u8 = undefined; // (256 * 6) / 8 = 192 bytes + polyPackW1(p, gamma1_bits, &buf); + + // Verify basic properties + // All bytes should be used + var non_zero = false; + for (buf) |b| { + if (b != 0) { + non_zero = true; + break; + } + } + try testing.expect(non_zero); +} + +test "polyPackW1 for gamma1_bits=19" { + const gamma1_bits = 19; + + // Create a test polynomial with small coefficients (w1 values < 16) + var p = Poly.zero; + for (0..N) 
|i| { + p.cs[i] = @intCast(i % 16); // 4-bit values + } + + // Pack it + var buf: [128]u8 = undefined; // (256 * 4) / 8 = 128 bytes + polyPackW1(p, gamma1_bits, &buf); + + // Verify basic properties + var non_zero = false; + for (buf) |b| { + if (b != 0) { + non_zero = true; + break; + } + } + try testing.expect(non_zero); +} + +test "makeHint and useHint correctness for gamma2=261888" { + // Test for ML-DSA-65 and ML-DSA-87 + const gamma2: u32 = 261888; + + // Test a selection of values to verify the hint mechanism works + const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 }; + + for (test_values) |w| { + // Decompose w to get w0 and w1 + const decomp = decompose(w, gamma2); + const w0_plus_q = decomp.a0_plus_q; + const w1 = decomp.a1; + + // Test with various small perturbations f in [0, gamma2] + const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 }; + + for (perturbations) |f| { + // Test f (positive perturbation) + const z0_pos = (w0_plus_q +% Q -% f) % Q; + const hint_pos = makeHint(z0_pos, w1, gamma2); + const w_perturbed_pos = (w +% Q -% f) % Q; + const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2); + try testing.expectEqual(w1, w1_recovered_pos); + + // Test -f (negative perturbation) + if (f > 0) { + const z0_neg = (w0_plus_q +% f) % Q; + const hint_neg = makeHint(z0_neg, w1, gamma2); + const w_perturbed_neg = (w +% f) % Q; + const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2); + try testing.expectEqual(w1, w1_recovered_neg); + } + } + } +} + +test "makeHint and useHint correctness for gamma2=95232" { + // Test for ML-DSA-44 + const gamma2: u32 = 95232; + + // Test a selection of values to verify the hint mechanism works + const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 }; + + for (test_values) |w| { + // Decompose w to get w0 and w1 + const decomp = decompose(w, gamma2); + const w0_plus_q = decomp.a0_plus_q; + const w1 = decomp.a1; + + // Test with various small perturbations f in [0, gamma2] + const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 }; + + for (perturbations) |f| { + // Test f (positive perturbation) + const z0_pos = (w0_plus_q +% Q -% f) % Q; + const hint_pos = makeHint(z0_pos, w1, gamma2); + const w_perturbed_pos = (w +% Q -% f) % Q; + const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2); + try testing.expectEqual(w1, w1_recovered_pos); + + // Test -f (negative perturbation) + if (f > 0) { + const z0_neg = (w0_plus_q +% f) % Q; + const hint_neg = makeHint(z0_neg, w1, gamma2); + const w_perturbed_neg = (w +% f) % Q; + const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2); + try testing.expectEqual(w1, w1_recovered_neg); + } + } + } +} + +test "polyMakeHint basic functionality" { + const gamma2: u32 = 261888; + + // Create test polynomials + var p0 = Poly.zero; + var p1 = Poly.zero; + + // Fill with test values + for (0..N) |i| { + p0.cs[i] = @intCast((i * 17) % Q); + p1.cs[i] = @intCast((i * 3) % 16); // High bits are at most 15 for gamma2=261888 + } + + // Make hints + const result = polyMakeHint(p0, p1, gamma2); + const hint = result.hint; + const count = result.count; + + // Verify that hints are binary + for (0..N) |i| { + try testing.expect(hint.cs[i] == 0 or hint.cs[i] == 1); + } + + // Verify that count matches the number of 1s in hint + var actual_count: u32 = 0; + for (0..N) |i| { + actual_count += hint.cs[i]; + } + try testing.expectEqual(count, actual_count); +} + +test "polyUseHint reconstruction" { 
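+ // No perturbation is applied in this test: the hints come from q's own
+ // decomposition, so useHint must return q's high bits unchanged.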
+ const gamma2: u32 = 261888;
+
+ // Create a test polynomial q
+ var q = Poly.zero;
+ for (0..N) |i| {
+ q.cs[i] = @intCast((i * 123) % Q);
+ }
+
+ // Decompose q to get high and low bits
+ var q0_plus_q_array: [N]u32 = undefined;
+ var q1_array: [N]u32 = undefined;
+ for (0..N) |i| {
+ const decomp = decompose(q.cs[i], gamma2);
+ q0_plus_q_array[i] = decomp.a0_plus_q;
+ q1_array[i] = decomp.a1;
+ }
+
+ const q0_plus_q = Poly{ .cs = q0_plus_q_array };
+ const q1 = Poly{ .cs = q1_array };
+
+ // Create hints (they'll mostly be 0 here, since the low bits come from q's own decomposition and are unperturbed)
+ const hint_result = polyMakeHint(q0_plus_q, q1, gamma2);
+ const hint = hint_result.hint;
+
+ // Use hints to recover high bits
+ const recovered = polyUseHint(q, hint, gamma2);
+
+ // Recovered should match original high bits q1
+ for (0..N) |i| {
+ try testing.expectEqual(q1.cs[i], recovered.cs[i]);
+ }
+}
+
+test "hint roundtrip with perturbation" {
+ const gamma2: u32 = 261888;
+
+ // Create a test polynomial w
+ var w = Poly.zero;
+ for (0..N) |i| {
+ w.cs[i] = @intCast((i * 7919) % Q);
+ }
+
+ // Decompose w to get w0 and w1
+ var w0_plus_q = Poly.zero;
+ var w1 = Poly.zero;
+ for (0..N) |i| {
+ const decomp = decompose(w.cs[i], gamma2);
+ w0_plus_q.cs[i] = decomp.a0_plus_q;
+ w1.cs[i] = decomp.a1;
+ }
+
+ // Apply a small perturbation
+ var f = Poly.zero;
+ for (0..N) |i| {
+ // Small perturbation with |f| < 1000, well within (-gamma2, gamma2)
+ const f_val = @as(u32, @intCast(i % 1000));
+ f.cs[i] = if (i % 2 == 0) f_val else Q -% f_val;
+ }
+
+ // Compute w' = w - f and z0 = w0 - f
+ var w_prime = Poly.zero;
+ var z0 = Poly.zero;
+ for (0..N) |i| {
+ w_prime.cs[i] = (w.cs[i] +% Q -% f.cs[i]) % Q;
+ z0.cs[i] = (w0_plus_q.cs[i] +% Q -% f.cs[i]) % Q;
+ }
+
+ // Make hints
+ const hint_result = polyMakeHint(z0, w1, gamma2);
+ const hint = hint_result.hint;
+
+ // Use hints to recover w1 from w_prime
+ const w1_recovered = polyUseHint(w_prime, hint, gamma2);
+
+ // Verify that we recovered the original high bits
+ for (0..N) |i| {
+ try testing.expectEqual(w1.cs[i], w1_recovered.cs[i]);
+ }
+}
+
+// Parameterized test helper for key generation
+
+fn testKeyGenerationBasic(comptime MlDsa: type, seed: [32]u8) !void {
+ const result = MlDsa.newKeyFromSeed(&seed);
+ const pk = result.pk;
+ const sk = result.sk;
+
+ // Basic sanity checks
+ try testing.expect(pk.rho.len == 32);
+ try testing.expect(sk.rho.len == 32);
+ try testing.expectEqualSlices(u8, &pk.rho, &sk.rho);
+
+ // Verify tr matches between pk and sk
+ try testing.expectEqualSlices(u8, &pk.tr, &sk.tr);
+
+ // Test toBytes/fromBytes round-trip for public key
+ const pk_bytes = pk.toBytes();
+ const pk2 = try MlDsa.PublicKey.fromBytes(pk_bytes);
+ try testing.expectEqualSlices(u8, &pk.rho, &pk2.rho);
+ try testing.expectEqualSlices(u8, &pk.tr, &pk2.tr);
+
+ // Test toBytes/fromBytes round-trip for secret key
+ const sk_bytes = sk.toBytes();
+ const sk2 = try MlDsa.SecretKey.fromBytes(sk_bytes);
+ try testing.expectEqualSlices(u8, &sk.rho, &sk2.rho);
+ try testing.expectEqualSlices(u8, &sk.key, &sk2.key);
+ try testing.expectEqualSlices(u8, &sk.tr, &sk2.tr);
+}
+
+test "Key generation basic - all variants" {
+ inline for (.{
+ .{ .variant = MLDSA44, .seed_byte = 0x44 },
+ .{ .variant = MLDSA65, .seed_byte = 0x65 },
+ .{ .variant = MLDSA87, .seed_byte = 0x87 },
+ }) |config| {
+ const seed = [_]u8{config.seed_byte} ** 32;
+ try testKeyGenerationBasic(config.variant, seed);
+ }
+}
+
+test "Key generation determinism" {
+ const seed = [_]u8{ 0x12, 0x34, 0x56, 0x78 } ++ [_]u8{0xAB} ** 28;
+
// Generate two key pairs from the same seed + const result1 = MLDSA44.newKeyFromSeed(&seed); + const result2 = MLDSA44.newKeyFromSeed(&seed); + + // They should be identical + const pk_bytes1 = result1.pk.toBytes(); + const pk_bytes2 = result2.pk.toBytes(); + try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2); + + const sk_bytes1 = result1.sk.toBytes(); + const sk_bytes2 = result2.sk.toBytes(); + try testing.expectEqualSlices(u8, &sk_bytes1, &sk_bytes2); +} + +test "Private key can compute public key" { + const seed = [_]u8{0xFF} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const pk = result.pk; + const sk = result.sk; + + // Compute public key from private key + const pk_from_sk = sk.public(); + + // Pack both public keys and compare + const pk_bytes1 = pk.toBytes(); + const pk_bytes2 = pk_from_sk.toBytes(); + + try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2); +} + +// Parameterized test helper for sign and verify +fn testSignAndVerify(comptime MlDsa: type, seed: [32]u8, message: []const u8) !void { + const result = MlDsa.newKeyFromSeed(&seed); + const kp = try MlDsa.KeyPair.fromSecretKey(result.sk); + + // Sign the message + const sig = try kp.sign(message, null); + + // Verify the signature + try sig.verify(message, kp.public_key); +} + +test "Sign and verify - all variants" { + inline for (.{ + .{ .variant = MLDSA44, .seed_byte = 0x44, .message = "Hello, ML-DSA-44!" }, + .{ .variant = MLDSA65, .seed_byte = 0x65, .message = "Hello, ML-DSA-65!" }, + .{ .variant = MLDSA87, .seed_byte = 0x87, .message = "Hello, ML-DSA-87!" }, + }) |config| { + const seed = [_]u8{config.seed_byte} ** 32; + try testSignAndVerify(config.variant, seed, config.message); + } +} + +test "Invalid signature rejection" { + const seed = [_]u8{0x99} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + const message = "Original message"; + + // Sign the message + const sig = try kp.sign(message, null); + + // Verify with wrong message should fail + const wrong_message = "Modified message"; + try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_message, kp.public_key)); + + // Modify signature and verify should fail + var corrupted_sig_bytes = sig.toBytes(); + corrupted_sig_bytes[0] ^= 0xFF; + const corrupted_sig = try MLDSA44.Signature.fromBytes(corrupted_sig_bytes); + try testing.expectError(error.SignatureVerificationFailed, corrupted_sig.verify(message, kp.public_key)); +} + +test "Context string support" { + const seed = [_]u8{0xAA} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + const message = "Test message"; + const context1 = "context1"; + const context2 = "context2"; + + // Sign with context1 + const sig1 = try kp.signWithContext(message, null, context1); + + // Verify with correct context should succeed + try sig1.verifyWithContext(message, kp.public_key, context1); + + // Verify with wrong context should fail + try testing.expectError(error.SignatureVerificationFailed, sig1.verifyWithContext(message, kp.public_key, context2)); + + // Verify with empty context should fail + try testing.expectError(error.SignatureVerificationFailed, sig1.verify(message, kp.public_key)); + + // Sign with empty context + const sig2 = try kp.sign(message, null); + + // Verify with empty context should succeed + try sig2.verify(message, kp.public_key); + + // Verify with non-empty context should fail + try 
testing.expectError(error.SignatureVerificationFailed, sig2.verifyWithContext(message, kp.public_key, context1)); + + // Test maximum context length (255 bytes) + const max_context = [_]u8{0xBB} ** 255; + const sig3 = try kp.signWithContext(message, null, &max_context); + try sig3.verifyWithContext(message, kp.public_key, &max_context); + + // Test context too long (256 bytes should fail) + const too_long_context = [_]u8{0xCC} ** 256; + try testing.expectError(error.ContextTooLong, kp.signWithContext(message, null, &too_long_context)); +} + +test "Context string with streaming API" { + const seed = [_]u8{0xDD} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + const context = "streaming-context"; + const message_part1 = "Hello, "; + const message_part2 = "World!"; + + // Sign using streaming API with context + var signer = try kp.signerWithContext(null, context); + signer.update(message_part1); + signer.update(message_part2); + const sig = signer.finalize(); + + // Verify using streaming API with context + var verifier = try sig.verifierWithContext(kp.public_key, context); + verifier.update(message_part1); + verifier.update(message_part2); + try verifier.verify(); + + // Verify with wrong context should fail + var verifier_wrong = try sig.verifierWithContext(kp.public_key, "wrong"); + verifier_wrong.update(message_part1); + verifier_wrong.update(message_part2); + try testing.expectError(error.SignatureVerificationFailed, verifier_wrong.verify()); +} + +test "Signature determinism (same rnd)" { + const seed = [_]u8{0x11} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const sk = result.sk; + + const message = "Deterministic test"; + const rnd = [_]u8{0x22} ** 32; + + // Sign twice with same randomness using streaming API + var st1 = try sk.signer(rnd); + st1.update(message); + const sig1 = st1.finalize(); + + var st2 = try sk.signer(rnd); + st2.update(message); + const sig2 = st2.finalize(); + + // Signatures should be identical + try testing.expectEqualSlices(u8, &sig1.toBytes(), &sig2.toBytes()); +} + +test "Signature toBytes/fromBytes roundtrip" { + const seed = [_]u8{0x33} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + const message = "toBytes/fromBytes test"; + + // Sign the message + const sig = try kp.sign(message, null); + const sig_bytes = sig.toBytes(); + + // Unpack and repack + const sig_reparsed = try MLDSA44.Signature.fromBytes(sig_bytes); + + const repacked = sig_reparsed.toBytes(); + + // Should match original + try testing.expectEqualSlices(u8, &sig_bytes, &repacked); +} + +test "Empty message signing" { + const seed = [_]u8{0x44} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + const message = ""; + + // Sign empty message + const sig = try kp.sign(message, null); + + // Verify should work + try sig.verify(message, kp.public_key); +} + +test "Long message signing" { + const seed = [_]u8{0x55} ** 32; + const result = MLDSA44.newKeyFromSeed(&seed); + const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); + + // Create a long message (1KB) + const long_message = [_]u8{0xAB} ** 1024; + + // Sign long message + const sig = try kp.sign(&long_message, null); + + // Verify should work + try sig.verify(&long_message, kp.public_key); +} + +// Helper function to decode hex string into bytes +fn hexToBytes(comptime hex: []const u8, out: []u8) !void { + if (hex.len != 
out.len * 2) return error.InvalidLength; + + var i: usize = 0; + while (i < out.len) : (i += 1) { + const hi = try std.fmt.charToDigit(hex[i * 2], 16); + const lo = try std.fmt.charToDigit(hex[i * 2 + 1], 16); + out[i] = (hi << 4) | lo; + } +} + +test "ML-DSA-44 KAT test vector 0" { + // Test vector from NIST ML-DSA KAT (count = 0) + // xi is the seed for key generation (Algorithm 1, line 1) + const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; + const pk_hex_start = "bd4e96f9a038ab5e36214fe69c0b1cb835ef9d7c8417e76aecd152f5cddebec8"; + const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; + + // Parse xi (32-byte seed for key generation) + var xi: [32]u8 = undefined; + try hexToBytes(xi_hex, &xi); + + // Generate keys from xi + const result = MLDSA44.newKeyFromSeed(&xi); + const pk = result.pk; + const sk = result.sk; + + // Verify public key starts with expected bytes + const pk_bytes = pk.toBytes(); + + var expected_pk_start: [32]u8 = undefined; + try hexToBytes(pk_hex_start, &expected_pk_start); + + // Check first 32 bytes of public key match + try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]); + + // Parse message + var msg: [16]u8 = undefined; + try hexToBytes(msg_hex, &msg); + + // Sign the message (deterministic mode with fixed randomness) + const kp = try MLDSA44.KeyPair.fromSecretKey(sk); + const sig = try kp.sign(&msg, null); + + // Verify the signature + try sig.verify(&msg, kp.public_key); +} + +test "ML-DSA-65 KAT test vector 0" { + // Test vector from NIST ML-DSA KAT (count = 0) + // xi is the seed for key generation (Algorithm 1, line 1) + const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; + const pk_hex_start = "e50d03fff3b3a70961abbb92a390008dec1283f603f50cdbaaa3d00bd659bc76"; + const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; + + // Parse xi (32-byte seed for key generation) + var xi: [32]u8 = undefined; + try hexToBytes(xi_hex, &xi); + + // Generate keys from xi + const result = MLDSA65.newKeyFromSeed(&xi); + const pk = result.pk; + const sk = result.sk; + + // Verify public key starts with expected bytes + const pk_bytes = pk.toBytes(); + + var expected_pk_start: [32]u8 = undefined; + try hexToBytes(pk_hex_start, &expected_pk_start); + + // Check first 32 bytes of public key match + try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]); + + // Parse message + var msg: [16]u8 = undefined; + try hexToBytes(msg_hex, &msg); + + // Sign the message + const kp = try MLDSA65.KeyPair.fromSecretKey(sk); + const sig = try kp.sign(&msg, null); + + // Verify the signature + try sig.verify(&msg, kp.public_key); +} + +test "ML-DSA-87 KAT test vector 0" { + // Test vector from NIST ML-DSA KAT (count = 0) + // xi is the seed for key generation (Algorithm 1, line 1) + const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; + const pk_hex_start = "bc89b367d4288f47c71a74679d0fcffbe041de41b5da2f5fc66d8e28c5899494"; + const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; + + // Parse xi (32-byte seed for key generation) + var xi: [32]u8 = undefined; + try hexToBytes(xi_hex, &xi); + + // Generate keys from xi + const result = MLDSA87.newKeyFromSeed(&xi); + const pk = result.pk; + const sk = result.sk; + + // Verify public key starts with expected bytes + const pk_bytes = pk.toBytes(); + + var expected_pk_start: [32]u8 = undefined; + try hexToBytes(pk_hex_start, &expected_pk_start); + + // Check first 32 bytes of public key match + try testing.expectEqualSlices(u8, 
&expected_pk_start, pk_bytes[0..32]); + + // Parse message + var msg: [16]u8 = undefined; + try hexToBytes(msg_hex, &msg); + + // Sign the message + const kp = try MLDSA87.KeyPair.fromSecretKey(sk); + const sig = try kp.sign(&msg, null); + + // Verify the signature + try sig.verify(&msg, kp.public_key); +} + +test "KeyPair API - generate and sign" { + // Test the new KeyPair API with random generation + const kp = MLDSA44.KeyPair.generate(); + const msg = "Test message for KeyPair API"; + + // Sign with deterministic mode (no noise) + const sig = try kp.sign(msg, null); + + // Verify using Signature.verify API + try sig.verify(msg, kp.public_key); +} + +test "KeyPair API - generateDeterministic" { + // Test deterministic key generation + const seed = [_]u8{42} ** 32; + const kp1 = try MLDSA44.KeyPair.generateDeterministic(seed); + const kp2 = try MLDSA44.KeyPair.generateDeterministic(seed); + + // Same seed should produce same keys + const pk1_bytes = kp1.public_key.toBytes(); + const pk2_bytes = kp2.public_key.toBytes(); + try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes); +} + +test "KeyPair API - fromSecretKey" { + // Generate a key pair + const kp1 = MLDSA44.KeyPair.generate(); + + // Derive public key from secret key + const kp2 = try MLDSA44.KeyPair.fromSecretKey(kp1.secret_key); + + // Public keys should match + const pk1_bytes = kp1.public_key.toBytes(); + const pk2_bytes = kp2.public_key.toBytes(); + try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes); +} + +test "Signature verification with noise" { + // Test signing with randomness (hedged signatures) + const kp = MLDSA65.KeyPair.generate(); + const msg = "Message to be signed with randomness"; + + // Create some noise + const noise = [_]u8{ 1, 2, 3, 4, 5 } ++ [_]u8{0} ** 27; + + // Sign with noise + const sig = try kp.sign(msg, noise); + + // Verify should still work + try sig.verify(msg, kp.public_key); +} + +test "Signature verification failure" { + // Test that invalid signatures are rejected + const kp = MLDSA44.KeyPair.generate(); + const msg = "Original message"; + const sig = try kp.sign(msg, null); + + // Verify with wrong message should fail + const wrong_msg = "Different message"; + try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_msg, kp.public_key)); +} + +test "Streaming API - sign and verify" { + const seed = [_]u8{0x55} ** 32; + const kp = try MLDSA44.KeyPair.generateDeterministic(seed); + + const msg = "Test message for streaming API"; + + // Sign using streaming API + var signer = try kp.signer(null); + signer.update(msg); + const sig = signer.finalize(); + + // Verify using streaming API + var verifier = try sig.verifier(kp.public_key); + verifier.update(msg); + try verifier.verify(); +} + +test "Streaming API - chunked message" { + const seed = [_]u8{0x66} ** 32; + const kp = try MLDSA44.KeyPair.generateDeterministic(seed); + + // Create a message in chunks + const chunk1 = "Hello, "; + const chunk2 = "streaming "; + const chunk3 = "world!"; + const full_msg = chunk1 ++ chunk2 ++ chunk3; + + // Sign with chunks + var signer = try kp.signer(null); + signer.update(chunk1); + signer.update(chunk2); + signer.update(chunk3); + const sig_chunked = signer.finalize(); + + // Sign with full message for comparison + var signer2 = try kp.signer(null); + signer2.update(full_msg); + const sig_full = signer2.finalize(); + + // Signatures should be identical + try testing.expectEqualSlices(u8, &sig_chunked.toBytes(), &sig_full.toBytes()); + + // Verify with chunks + const sig = 
sig_chunked;
+ var verifier = try sig.verifier(kp.public_key);
+ verifier.update(chunk1);
+ verifier.update(chunk2);
+ verifier.update(chunk3);
+ try verifier.verify();
+}
+
+test "Streaming API - large message" {
+ const seed = [_]u8{0x77} ** 32;
+ const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
+
+ // Create a large message (1MB)
+ const chunk_size = 4096;
+ const num_chunks = 256;
+ var chunk: [chunk_size]u8 = undefined;
+ for (0..chunk_size) |i| {
+ chunk[i] = @intCast(i % 256);
+ }
+
+ // Sign streaming
+ var signer = try kp.signer(null);
+ for (0..num_chunks) |_| {
+ signer.update(&chunk);
+ }
+ const sig = signer.finalize();
+
+ // Verify streaming
+ var verifier = try sig.verifier(kp.public_key);
+ for (0..num_chunks) |_| {
+ verifier.update(&chunk);
+ }
+ try verifier.verify();
+}
+
+test "Streaming API - all parameter sets" {
+ const test_msg = "Streaming test for all ML-DSA parameter sets";
+
+ // ML-DSA-44
+ {
+ const seed = [_]u8{0x44} ** 32;
+ const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
+ var signer = try kp.signer(null);
+ signer.update(test_msg);
+ const sig = signer.finalize();
+ var verifier = try sig.verifier(kp.public_key);
+ verifier.update(test_msg);
+ try verifier.verify();
+ }
+
+ // ML-DSA-65
+ {
+ const seed = [_]u8{0x65} ** 32;
+ const kp = try MLDSA65.KeyPair.generateDeterministic(seed);
+ var signer = try kp.signer(null);
+ signer.update(test_msg);
+ const sig = signer.finalize();
+ var verifier = try sig.verifier(kp.public_key);
+ verifier.update(test_msg);
+ try verifier.verify();
+ }
+
+ // ML-DSA-87
+ {
+ const seed = [_]u8{0x87} ** 32;
+ const kp = try MLDSA87.KeyPair.generateDeterministic(seed);
+ var signer = try kp.signer(null);
+ signer.update(test_msg);
+ const sig = signer.finalize();
+ var verifier = try sig.verifier(kp.public_key);
+ verifier.update(test_msg);
+ try verifier.verify();
+ }
+}
+
+/// Extended Euclidean Algorithm.
+/// Only meant to be used on comptime values; correctness matters, performance doesn't.
+fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
+ var a = a_;
+ var b = b_;
+ var x0: T = 1;
+ var x1: T = 0;
+ var y0: T = 0;
+ var y1: T = 1;
+
+ while (b != 0) {
+ const q = @divTrunc(a, b);
+ const temp_a = a;
+ a = b;
+ b = temp_a - q * b;
+
+ const temp_x = x0;
+ x0 = x1;
+ x1 = temp_x - q * x1;
+
+ const temp_y = y0;
+ y0 = y1;
+ y1 = temp_y - q * y1;
+ }
+
+ return .{ .gcd = a, .x = x0, .y = y0 };
+}
+
+/// Modular inversion: computes a^(-1) mod p.
+/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p).
+fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
+ // Use a signed type for EEA computation
+ const type_info = @typeInfo(T);
+ const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
+ std.meta.Int(.signed, type_info.int.bits)
+ else
+ T;
+
+ const a_signed = @as(SignedT, @intCast(a));
+ const p_signed = @as(SignedT, @intCast(p));
+
+ const r = extendedEuclidean(SignedT, a_signed, p_signed);
+ assert(r.gcd == 1);
+
+ // Normalize result to [0, p)
+ var result = r.x;
+ while (result < 0) {
+ result += p_signed;
+ }
+
+ return @intCast(result);
+}
+
+/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
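+/// For example, `modularPow(u32, 3, 5, 7)` evaluates to 5, since 3^5 = 243 ≡ 5 (mod 7).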
+fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
+ const type_info = @typeInfo(T);
+ const bits = type_info.int.bits;
+ const WideT = std.meta.Int(.unsigned, bits * 2);
+
+ var ret: T = 1;
+ var base: T = a;
+ var exp = s;
+
+ while (exp > 0) {
+ if (exp & 1 == 1) {
+ ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
+ }
+ base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
+ exp >>= 1;
+ }
+
+ return ret;
+}
+
+/// Creates an all-ones or all-zeros mask from a single bit value.
+/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
+fn bitMask(comptime T: type, bit: T) T {
+ const type_info = @typeInfo(T);
+ if (type_info != .int or type_info.int.signedness != .unsigned) {
+ @compileError("bitMask requires an unsigned integer type");
+ }
+ return -%bit;
+}
+
+/// Creates a mask from the sign bit of a signed integer.
+/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0.
+fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
+ const type_info = @typeInfo(T);
+ if (type_info != .int) {
+ @compileError("signMask requires an integer type");
+ }
+
+ const bits = type_info.int.bits;
+ const SignedT = std.meta.Int(.signed, bits);
+
+ // Convert to signed if needed, arithmetic right shift to propagate sign bit
+ const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
+ const shifted = x_signed >> (bits - 1);
+ return @bitCast(shifted);
+}
+
+/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
+/// This is a generic implementation parameterized by the modulus q, its
+/// inverse qInv mod R, and the Montgomery constant R = 2^r_bits.
+///
+/// For ML-DSA: R = 2^32, returns y < 2q
+/// For ML-KEM: R = 2^16, returns y in range (-q, q)
+fn montgomeryReduce(
+ comptime InT: type,
+ comptime OutT: type,
+ comptime q: comptime_int,
+ comptime qInv: comptime_int,
+ comptime r_bits: comptime_int,
+ x: InT,
+) OutT {
+ const mask = (@as(InT, 1) << r_bits) - 1;
+ const m_full = (x *% qInv) & mask;
+ const m: OutT = @truncate(m_full);
+
+ const yR = x -% @as(InT, m) * @as(InT, q);
+ const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).int.bits), @bitCast(yR)) >> r_bits;
+ return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).int.bits), @truncate(y_shifted)));
+}
+
+/// Uniform sampling using SHAKE-128 with rejection sampling.
+/// Samples polynomial coefficients uniformly from [0, q).
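+/// Candidate values outside [0, q) are rejected, and more SHAKE-128 output is
+/// squeezed until all coefficients have been filled.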
+/// +/// Parameters: +/// - PolyType: The polynomial type to return +/// - q: Modulus +/// - bits_per_coef: Number of bits per coefficient (12 or 23) +/// - n: Number of coefficients +/// - seed: Random seed +/// - domain_sep: Domain separation bytes (appended to seed) +fn sampleUniformRejection( + comptime PolyType: type, + comptime q: comptime_int, + comptime bits_per_coef: comptime_int, + comptime n: comptime_int, + seed: []const u8, + domain_sep: []const u8, +) PolyType { + var h = sha3.Shake128.init(.{}); + h.update(seed); + h.update(domain_sep); + + const buf_len = sha3.Shake128.block_length; // 168 bytes + var buf: [buf_len]u8 = undefined; + + var ret: PolyType = undefined; + var coef_idx: usize = 0; + + if (bits_per_coef == 12) { + // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each) + outer: while (true) { + h.squeeze(&buf); + + var j: usize = 0; + while (j < buf_len) : (j += 3) { + const b0 = @as(u16, buf[j]); + const b1 = @as(u16, buf[j + 1]); + const b2 = @as(u16, buf[j + 2]); + + const ts: [2]u16 = .{ + b0 | ((b1 & 0xf) << 8), + (b1 >> 4) | (b2 << 4), + }; + + inline for (ts) |t| { + if (t < q) { + ret.cs[coef_idx] = @intCast(t); + coef_idx += 1; + if (coef_idx == n) break :outer; + } + } + } + } + } else if (bits_per_coef == 23) { + // ML-DSA path: 1 coefficient per 3 bytes (23 bits) + while (coef_idx < n) { + h.squeeze(&buf); + + var j: usize = 0; + while (j < buf_len and coef_idx < n) : (j += 3) { + const t = (@as(u32, buf[j]) | + (@as(u32, buf[j + 1]) << 8) | + (@as(u32, buf[j + 2]) << 16)) & 0x7fffff; + + if (t < q) { + ret.cs[coef_idx] = @intCast(t); + coef_idx += 1; + } + } + } + } else { + @compileError("bits_per_coef must be 12 or 23"); + } + + return ret; +} + +test "bitMask and signMask helpers" { + try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0)); + try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1)); + try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0)); + try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1)); + try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0)); + try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1)); + + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1)); + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100)); + + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set + try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear +}
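+
+// Small hand-verified sanity checks for the comptime number-theory helpers above.
+test "modularPow and modularInverse on small values" {
+ // 3^5 = 243 ≡ 5 (mod 7)
+ try testing.expectEqual(@as(u32, 5), modularPow(u32, 3, 5, 7));
+ // 3 * 5 = 15 ≡ 1 (mod 7), so 3^(-1) ≡ 5 (mod 7)
+ try testing.expectEqual(@as(u32, 5), modularInverse(u32, 3, 7));
+}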