From 244a97e8ada5349136ca642d89092dbaf6e52ae2 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 20 Dec 2022 19:26:23 -0700 Subject: [PATCH] std.crypto.tls: certificate signature validation --- lib/std/crypto/CertificateBundle.zig | 313 ++++++++++++++++++++++++++- lib/std/crypto/tls/Client.zig | 50 ++--- 2 files changed, 326 insertions(+), 37 deletions(-) diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig index 83560b6367..17c0286bed 100644 --- a/lib/std/crypto/CertificateBundle.zig +++ b/lib/std/crypto/CertificateBundle.zig @@ -15,7 +15,7 @@ pub const Key = struct { /// The returned bytes become invalid after calling any of the rescan functions /// or add functions. -pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 { +pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 { const Adapter = struct { cb: CertificateBundle, @@ -29,8 +29,7 @@ pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 { return mem.eql(u8, a, b); } }; - const index = cb.map.getAdapted(subject_name, Adapter{ .cb = cb }) orelse return null; - return cb.bytes.items[index..]; + return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); } pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void { @@ -105,7 +104,7 @@ pub fn addCertsFromFile( const decoded_start = @intCast(u32, cb.bytes.items.len); const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); - const k = try key(cb, decoded_start); + const k = try cb.key(decoded_start); const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); if (gop.found_existing) { cb.bytes.items.len = decoded_start; @@ -115,16 +114,12 @@ pub fn addCertsFromFile( } } -pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key { +pub fn key(cb: CertificateBundle, bytes_index: u32) !Key { const bytes = cb.bytes.items; const certificate = try Der.parseElement(bytes, bytes_index); const tbs_certificate = try Der.parseElement(bytes, certificate.start); const version = try Der.parseElement(bytes, tbs_certificate.start); - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; - } + try checkVersion(bytes, version); const serial_number = try Der.parseElement(bytes, version.end); @@ -144,10 +139,173 @@ pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key { }; } +pub const Certificate = struct { + buffer: []const u8, + index: u32, + + pub fn verify(subject: Certificate, issuer: Certificate) !void { + const subject_certificate = try Der.parseElement(subject.buffer, subject.index); + const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start); + const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start); + try checkVersion(subject.buffer, subject_version); + const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." + const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end); + const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end); + const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end); + //const subject_name = try Der.parseElement(subject.buffer, subject_validity.end); + + const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end); + const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start); + const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem); + const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end); + const subject_sig = try parseBitString(subject, subject_sig_elem); + + const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index); + const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start); + const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start); + try checkVersion(issuer.buffer, issuer_version); + const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." + const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end); + const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end); + const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end); + const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end); + const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end); + const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start); + const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start); + const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem); + const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end); + const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem); + + // Check that the subject's issuer name matches the issuer's subject + // name. + if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) { + return error.CertificateIssuerMismatch; + } + + // TODO check the time validity for the subject + _ = subject_validity; + // TODO check the time validity for the issuer + + const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end]; + //std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo }); + switch (subject_algo) { + // zig fmt: off + .sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + // zig fmt: on + else => { + std.debug.print("unhandled algorithm: {any}\n", .{subject_algo}); + return error.UnsupportedCertificateSignatureAlgorithm; + }, + } + } + + pub fn contents(cert: Certificate, elem: Der.Element) []const u8 { + return cert.buffer[elem.start..elem.end]; + } + + pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 { + if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; + if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString; + return cert.buffer[elem.start + 1 .. elem.end]; + } + + fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void { + if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; + const pub_key_seq = try Der.parseElement(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const exponent = pub_key[exponent_elem.start..exponent_elem.end]; + if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; + if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const hash_der = switch (Hash) { + crypto.hash.Sha1 => [_]u8{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => [_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => [_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => [_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => [_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + + var msg_hashed: [Hash.digest_length]u8 = undefined; + Hash.hash(message, &msg_hashed, .{}); + + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; + const em: [modulus_len]u8 = + [2]u8{ 0, 1 } ++ + ([1]u8{0xff} ** ps_len) ++ + [1]u8{0} ++ + hash_der ++ + msg_hashed; + + const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); + + if (!mem.eql(u8, &em, &em_dec)) { + try std.testing.expectEqualSlices(u8, &em, &em_dec); + return error.CertificateSignatureInvalid; + } + }, + else => { + return error.CertificateSignatureUnsupportedBitCount; + }, + } + } +}; + +fn checkVersion(bytes: []const u8, version: Der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + const builtin = @import("builtin"); const std = @import("../std.zig"); const fs = std.fs; const mem = std.mem; +const crypto = std.crypto; const Allocator = std.mem.Allocator; const Der = std.crypto.Der; const CertificateBundle = @This(); @@ -177,3 +335,138 @@ test { try bundle.rescan(std.testing.allocator); } + +/// TODO: replace this with Frank's upcoming RSA implementation. the verify +/// function won't have the possibility of failure - it will either identify a +/// valid signature or an invalid signature. +/// This code is borrowed from https://github.com/shiguredo/tls13-zig +/// which is licensed under the Apache License Version 2.0, January 2004 +/// http://www.apache.org/licenses/ +/// The code has been modified. +const rsa = struct { + const BigInt = std.math.big.int.Managed; + + const PublicKey = struct { + n: BigInt, + e: BigInt, + + pub fn deinit(self: *PublicKey) void { + self.n.deinit(); + self.e.deinit(); + } + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { + var _n = try BigInt.init(allocator); + errdefer _n.deinit(); + try setBytes(&_n, modulus_bytes, allocator); + + var _e = try BigInt.init(allocator); + errdefer _e.deinit(); + try setBytes(&_e, pub_bytes, allocator); + + return .{ + .n = _n, + .e = _e, + }; + } + }; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { + var m = try BigInt.init(allocator); + defer m.deinit(); + + try setBytes(&m, &msg, allocator); + + if (m.order(public_key.n) != .lt) { + return error.MessageTooLong; + } + + var e = try BigInt.init(allocator); + defer e.deinit(); + + try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); + + var res: [modulus_len]u8 = undefined; + + try toBytes(&res, &e, allocator); + + return res; + } + + fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { + try r.set(0); + var tmp = try BigInt.init(allcator); + defer tmp.deinit(); + for (bytes) |b| { + try r.shiftLeft(r, 8); + try tmp.set(b); + try r.add(r, &tmp); + } + } + + fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var bin_raw: [512]u8 = undefined; + try toBytes(&bin_raw, x, allocator); + + var i: usize = 0; + while (bin_raw[i] == 0x00) : (i += 1) {} + const bin = bin_raw[i..]; + + try r.set(1); + var r1 = try BigInt.init(allocator); + defer r1.deinit(); + try BigInt.copy(&r1, a.toConst()); + i = 0; + while (i < bin.len * 8) : (i += 1) { + if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { + try BigInt.mul(&r1, r, &r1); + try mod(&r1, &r1, n, allocator); + try BigInt.sqr(r, r); + try mod(r, r, n, allocator); + } else { + try BigInt.mul(r, r, &r1); + try mod(r, r, n, allocator); + try BigInt.sqr(&r1, &r1); + try mod(&r1, &r1, n, allocator); + } + } + } + + fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { + const Error = error{ + BufferTooSmall, + }; + + var mask = try BigInt.initSet(allocator, 0xFF); + defer mask.deinit(); + var tmp = try BigInt.init(allocator); + defer tmp.deinit(); + + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a.toConst()); + + // Encoding into big-endian bytes + var i: usize = 0; + while (i < out.len) : (i += 1) { + try tmp.bitAnd(&a_copy, &mask); + const b = try tmp.to(u8); + out[out.len - i - 1] = b; + try a_copy.shiftRight(&a_copy, 8); + } + + if (!a_copy.eqZero()) { + return Error.BufferTooSmall; + } + } + + fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var q = try BigInt.init(allocator); + defer q.deinit(); + + try BigInt.divFloor(&q, rem, a, n); + } + + // TODO: flush the toilet + const poop = std.heap.page_allocator; +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 45c96ed290..8395be4551 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -53,15 +53,12 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con 0x02, // byte length of supported versions 0x03, 0x04, // TLS 1.3 }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, - .rsa_pkcs1_sha512, .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .ecdsa_secp521r1_sha512, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, .ed25519, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ .secp256r1, @@ -420,33 +417,32 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con // "This field MUST contain the same algorithm identifier as // the signatureAlgorithm field in the sequence Certificate." const signature = try Der.parseElement(handshake, serial_number.end); - const issuer = try Der.parseElement(handshake, signature.end); - const validity = try Der.parseElement(handshake, issuer.end); - const subject = try Der.parseElement(handshake, validity.end); - const subject_pub_key = try Der.parseElement(handshake, subject.end); - const extensions = try Der.parseElement(handshake, subject_pub_key.end); - _ = extensions; + const issuer_elem = try Der.parseElement(handshake, signature.end); - const signature_algorithm = try Der.parseElement(handshake, tbs_certificate.end); - const signature_value = try Der.parseElement(handshake, signature_algorithm.end); - _ = signature_value; - - const algorithm_elem = try Der.parseElement(handshake, signature_algorithm.start); - const algorithm = try Der.parseObjectId(handshake, algorithm_elem); - std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm}); - //const parameters = try Der.parseElement(signature_algorithm.contents, &sa_i); + const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end]; + if (ca_bundle.find(issuer_bytes)) |ca_cert_i| { + const Certificate = crypto.CertificateBundle.Certificate; + const subject: Certificate = .{ + .buffer = handshake, + .index = hs_i, + }; + const issuer: Certificate = .{ + .buffer = ca_bundle.bytes.items, + .index = ca_cert_i, + }; + if (subject.verify(issuer)) |_| { + std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); + } else |err| { + std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{ + @errorName(err), + }); + } + } hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; hs_i += total_ext_size; - - const issuer_bytes = handshake[issuer.start..issuer.end]; - const ca_cert = ca_bundle.find(issuer_bytes); - - std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions. ca_found={any}\n", .{ - cert_size, total_ext_size, ca_cert != null, - }); } }, @enumToInt(HandshakeType.certificate_verify) => {