diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 8be281e9df..063f1d8077 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -917,18 +917,20 @@ pub const rsa = struct { return result; } - pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void { + pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void { const mod_bits = public_key.n.bits(); const em_dec = try encrypt(modulus_len, sig, public_key); - EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator) catch unreachable; + EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable; } - fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void { - // TODO + fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { // 1. If the length of M is greater than the input limitation for // the hash function (2^61 - 1 octets for SHA-1), output // "inconsistent" and stop. + // All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1. + // Even then, this check is only there for paranoia. In the context of TLS certifcates, emBit cannot exceed 4096. + if (emBit >= 1 << 61) return error.InvalidSignature; // emLen = \ceil(emBits/8) const emLen = ((emBit - 1) / 8) + 1; @@ -952,7 +954,7 @@ pub const rsa = struct { // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, // and let H be the next hLen octets. const maskedDB = em[0..(emLen - Hash.digest_length - 1)]; - const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)]; + const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)][0..Hash.digest_length]; // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in // maskedDB are not all equal to zero, output "inconsistent" and @@ -969,9 +971,12 @@ pub const rsa = struct { // 7. Let dbMask = MGF(H, emLen - hLen - 1). const mgf_len = emLen - Hash.digest_length - 1; - var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length); - defer allocator.free(mgf_out); - var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator); + var mgf_out_buf: [512]u8 = undefined; + if (mgf_len > mgf_out_buf.len) { // Modulus > 4096 bits + return error.InvalidSignature; + } + var mgf_out = mgf_out_buf[0 .. ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length]; + var dbMask = try MGF1(Hash, mgf_out, h, mgf_len); // 8. Let DB = maskedDB \xor dbMask. i = 0; @@ -1008,8 +1013,11 @@ pub const rsa = struct { // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; // M' is an octet string of length 8 + hLen + sLen with eight // initial zero octets. - var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen); - defer allocator.free(m_p); + if (sLen > Hash.digest_length) { // A seed larger than the hash length would be useless + return error.InvalidSignature; + } + var m_p_buf: [8 + Hash.digest_length + Hash.digest_length]u8 = undefined; + var m_p = m_p_buf[0 .. 8 + Hash.digest_length + sLen]; std.mem.copyForwards(u8, m_p, &([_]u8{0} ** 8)); std.mem.copyForwards(u8, m_p[8..], &mHash); std.mem.copyForwards(u8, m_p[(8 + Hash.digest_length)..], salt); @@ -1025,14 +1033,12 @@ pub const rsa = struct { } } - fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 { + fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) ![]u8 { var counter: usize = 0; var idx: usize = 0; var c: [4]u8 = undefined; - - var hash = try allocator.alloc(u8, seed.len + c.len); - defer allocator.free(hash); - std.mem.copyForwards(u8, hash, seed); + var hash: [Hash.digest_length + c.len]u8 = undefined; + @memcpy(hash[0..Hash.digest_length], seed); var hashed: [Hash.digest_length]u8 = undefined; while (idx < len) { @@ -1042,7 +1048,7 @@ pub const rsa = struct { c[3] = @intCast(u8, counter & 0xFF); std.mem.copyForwards(u8, hash[seed.len..], &c); - Hash.hash(hash, &hashed, .{}); + Hash.hash(&hash, &hashed, .{}); std.mem.copyForwards(u8, out[idx..], &hashed); idx += hashed.len; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index df6db1382b..0d404d29ac 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -424,7 +424,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In var handshake_state: HandshakeState = .encrypted_extensions; var cleartext_bufs: [2][8000]u8 = undefined; var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [300]u8 = undefined; + var main_cert_pub_key_buf: [600]u8 = undefined; var main_cert_pub_key_len: u16 = undefined; const now_sec = std.time.timestamp(); @@ -602,14 +602,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const components = try rsa.PublicKey.parseDer(main_cert_pub_key); const exponent = components.exponent; const modulus = components.modulus; - var rsa_mem_buf: [512 * 32]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); - const ally = fba.allocator(); switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { const key = try rsa.PublicKey.fromBytes(exponent, modulus); const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); }, else => { return error.TlsBadRsaSignatureBitCount;