std.crypto.tls: certificate signature validation

This commit is contained in:
Andrew Kelley 2022-12-20 19:26:23 -07:00
parent 504070e8fc
commit 244a97e8ad
2 changed files with 326 additions and 37 deletions

View File

@ -15,7 +15,7 @@ pub const Key = struct {
/// The returned bytes become invalid after calling any of the rescan functions
/// or add functions.
pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 {
pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 {
const Adapter = struct {
cb: CertificateBundle,
@ -29,8 +29,7 @@ pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 {
return mem.eql(u8, a, b);
}
};
const index = cb.map.getAdapted(subject_name, Adapter{ .cb = cb }) orelse return null;
return cb.bytes.items[index..];
return cb.map.getAdapted(subject_name, Adapter{ .cb = cb });
}
pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void {
@ -105,7 +104,7 @@ pub fn addCertsFromFile(
const decoded_start = @intCast(u32, cb.bytes.items.len);
const dest_buf = cb.bytes.allocatedSlice()[decoded_start..];
cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert);
const k = try key(cb, decoded_start);
const k = try cb.key(decoded_start);
const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb });
if (gop.found_existing) {
cb.bytes.items.len = decoded_start;
@ -115,16 +114,12 @@ pub fn addCertsFromFile(
}
}
pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key {
pub fn key(cb: CertificateBundle, bytes_index: u32) !Key {
const bytes = cb.bytes.items;
const certificate = try Der.parseElement(bytes, bytes_index);
const tbs_certificate = try Der.parseElement(bytes, certificate.start);
const version = try Der.parseElement(bytes, tbs_certificate.start);
if (@bitCast(u8, version.identifier) != 0xa0 or
!mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02"))
{
return error.UnsupportedCertificateVersion;
}
try checkVersion(bytes, version);
const serial_number = try Der.parseElement(bytes, version.end);
@ -144,10 +139,173 @@ pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key {
};
}
pub const Certificate = struct {
buffer: []const u8,
index: u32,
pub fn verify(subject: Certificate, issuer: Certificate) !void {
const subject_certificate = try Der.parseElement(subject.buffer, subject.index);
const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start);
const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start);
try checkVersion(subject.buffer, subject_version);
const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end);
// RFC 5280, section 4.1.2.3:
// "This field MUST contain the same algorithm identifier as
// the signatureAlgorithm field in the sequence Certificate."
const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end);
const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end);
const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end);
//const subject_name = try Der.parseElement(subject.buffer, subject_validity.end);
const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end);
const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start);
const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem);
const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end);
const subject_sig = try parseBitString(subject, subject_sig_elem);
const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index);
const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start);
const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start);
try checkVersion(issuer.buffer, issuer_version);
const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end);
// RFC 5280, section 4.1.2.3:
// "This field MUST contain the same algorithm identifier as
// the signatureAlgorithm field in the sequence Certificate."
const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end);
const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end);
const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end);
const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end);
const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end);
const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start);
const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start);
const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem);
const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end);
const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem);
// Check that the subject's issuer name matches the issuer's subject
// name.
if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) {
return error.CertificateIssuerMismatch;
}
// TODO check the time validity for the subject
_ = subject_validity;
// TODO check the time validity for the issuer
const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end];
//std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo });
switch (subject_algo) {
// zig fmt: off
.sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
.sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
.sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
.sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
.sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
// zig fmt: on
else => {
std.debug.print("unhandled algorithm: {any}\n", .{subject_algo});
return error.UnsupportedCertificateSignatureAlgorithm;
},
}
}
pub fn contents(cert: Certificate, elem: Der.Element) []const u8 {
return cert.buffer[elem.start..elem.end];
}
pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 {
if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType;
if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString;
return cert.buffer[elem.start + 1 .. elem.end];
}
fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void {
if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
const pub_key_seq = try Der.parseElement(pub_key, 0);
if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start);
if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end);
if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
// Skip over meaningless zeroes in the modulus.
const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end];
const modulus_offset = for (modulus_raw) |byte, i| {
if (byte != 0) break i;
} else modulus_raw.len;
const modulus = modulus_raw[modulus_offset..];
const exponent = pub_key[exponent_elem.start..exponent_elem.end];
if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
const hash_der = switch (Hash) {
crypto.hash.Sha1 => [_]u8{
0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e,
0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14,
},
crypto.hash.sha2.Sha224 => [_]u8{
0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05,
0x00, 0x04, 0x1c,
},
crypto.hash.sha2.Sha256 => [_]u8{
0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
0x00, 0x04, 0x20,
},
crypto.hash.sha2.Sha384 => [_]u8{
0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
0x00, 0x04, 0x30,
},
crypto.hash.sha2.Sha512 => [_]u8{
0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
0x00, 0x04, 0x40,
},
else => @compileError("unreachable"),
};
var msg_hashed: [Hash.digest_length]u8 = undefined;
Hash.hash(message, &msg_hashed, .{});
switch (modulus.len) {
inline 128, 256, 512 => |modulus_len| {
const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3;
const em: [modulus_len]u8 =
[2]u8{ 0, 1 } ++
([1]u8{0xff} ** ps_len) ++
[1]u8{0} ++
hash_der ++
msg_hashed;
const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop);
const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop);
if (!mem.eql(u8, &em, &em_dec)) {
try std.testing.expectEqualSlices(u8, &em, &em_dec);
return error.CertificateSignatureInvalid;
}
},
else => {
return error.CertificateSignatureUnsupportedBitCount;
},
}
}
};
fn checkVersion(bytes: []const u8, version: Der.Element) !void {
if (@bitCast(u8, version.identifier) != 0xa0 or
!mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02"))
{
return error.UnsupportedCertificateVersion;
}
}
const builtin = @import("builtin");
const std = @import("../std.zig");
const fs = std.fs;
const mem = std.mem;
const crypto = std.crypto;
const Allocator = std.mem.Allocator;
const Der = std.crypto.Der;
const CertificateBundle = @This();
@ -177,3 +335,138 @@ test {
try bundle.rescan(std.testing.allocator);
}
/// TODO: replace this with Frank's upcoming RSA implementation. the verify
/// function won't have the possibility of failure - it will either identify a
/// valid signature or an invalid signature.
/// This code is borrowed from https://github.com/shiguredo/tls13-zig
/// which is licensed under the Apache License Version 2.0, January 2004
/// http://www.apache.org/licenses/
/// The code has been modified.
const rsa = struct {
const BigInt = std.math.big.int.Managed;
const PublicKey = struct {
n: BigInt,
e: BigInt,
pub fn deinit(self: *PublicKey) void {
self.n.deinit();
self.e.deinit();
}
pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey {
var _n = try BigInt.init(allocator);
errdefer _n.deinit();
try setBytes(&_n, modulus_bytes, allocator);
var _e = try BigInt.init(allocator);
errdefer _e.deinit();
try setBytes(&_e, pub_bytes, allocator);
return .{
.n = _n,
.e = _e,
};
}
};
fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
var m = try BigInt.init(allocator);
defer m.deinit();
try setBytes(&m, &msg, allocator);
if (m.order(public_key.n) != .lt) {
return error.MessageTooLong;
}
var e = try BigInt.init(allocator);
defer e.deinit();
try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator);
var res: [modulus_len]u8 = undefined;
try toBytes(&res, &e, allocator);
return res;
}
fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void {
try r.set(0);
var tmp = try BigInt.init(allcator);
defer tmp.deinit();
for (bytes) |b| {
try r.shiftLeft(r, 8);
try tmp.set(b);
try r.add(r, &tmp);
}
}
fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
var bin_raw: [512]u8 = undefined;
try toBytes(&bin_raw, x, allocator);
var i: usize = 0;
while (bin_raw[i] == 0x00) : (i += 1) {}
const bin = bin_raw[i..];
try r.set(1);
var r1 = try BigInt.init(allocator);
defer r1.deinit();
try BigInt.copy(&r1, a.toConst());
i = 0;
while (i < bin.len * 8) : (i += 1) {
if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) {
try BigInt.mul(&r1, r, &r1);
try mod(&r1, &r1, n, allocator);
try BigInt.sqr(r, r);
try mod(r, r, n, allocator);
} else {
try BigInt.mul(r, r, &r1);
try mod(r, r, n, allocator);
try BigInt.sqr(&r1, &r1);
try mod(&r1, &r1, n, allocator);
}
}
}
fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void {
const Error = error{
BufferTooSmall,
};
var mask = try BigInt.initSet(allocator, 0xFF);
defer mask.deinit();
var tmp = try BigInt.init(allocator);
defer tmp.deinit();
var a_copy = try BigInt.init(allocator);
defer a_copy.deinit();
try a_copy.copy(a.toConst());
// Encoding into big-endian bytes
var i: usize = 0;
while (i < out.len) : (i += 1) {
try tmp.bitAnd(&a_copy, &mask);
const b = try tmp.to(u8);
out[out.len - i - 1] = b;
try a_copy.shiftRight(&a_copy, 8);
}
if (!a_copy.eqZero()) {
return Error.BufferTooSmall;
}
}
fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
var q = try BigInt.init(allocator);
defer q.deinit();
try BigInt.divFloor(&q, rem, a, n);
}
// TODO: flush the toilet
const poop = std.heap.page_allocator;
};

View File

@ -53,15 +53,12 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
0x02, // byte length of supported versions
0x03, 0x04, // TLS 1.3
}) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
.rsa_pkcs1_sha256,
.rsa_pkcs1_sha384,
.rsa_pkcs1_sha512,
.ecdsa_secp256r1_sha256,
.ecdsa_secp384r1_sha384,
.ecdsa_secp521r1_sha512,
.rsa_pss_rsae_sha256,
.rsa_pss_rsae_sha384,
.rsa_pss_rsae_sha512,
.rsa_pkcs1_sha256,
.rsa_pkcs1_sha384,
.rsa_pkcs1_sha512,
.ed25519,
})) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
.secp256r1,
@ -420,33 +417,32 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
// "This field MUST contain the same algorithm identifier as
// the signatureAlgorithm field in the sequence Certificate."
const signature = try Der.parseElement(handshake, serial_number.end);
const issuer = try Der.parseElement(handshake, signature.end);
const validity = try Der.parseElement(handshake, issuer.end);
const subject = try Der.parseElement(handshake, validity.end);
const subject_pub_key = try Der.parseElement(handshake, subject.end);
const extensions = try Der.parseElement(handshake, subject_pub_key.end);
_ = extensions;
const issuer_elem = try Der.parseElement(handshake, signature.end);
const signature_algorithm = try Der.parseElement(handshake, tbs_certificate.end);
const signature_value = try Der.parseElement(handshake, signature_algorithm.end);
_ = signature_value;
const algorithm_elem = try Der.parseElement(handshake, signature_algorithm.start);
const algorithm = try Der.parseObjectId(handshake, algorithm_elem);
std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm});
//const parameters = try Der.parseElement(signature_algorithm.contents, &sa_i);
const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end];
if (ca_bundle.find(issuer_bytes)) |ca_cert_i| {
const Certificate = crypto.CertificateBundle.Certificate;
const subject: Certificate = .{
.buffer = handshake,
.index = hs_i,
};
const issuer: Certificate = .{
.buffer = ca_bundle.bytes.items,
.index = ca_cert_i,
};
if (subject.verify(issuer)) |_| {
std.debug.print("found a root CA cert matching issuer. verification success!\n", .{});
} else |err| {
std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{
@errorName(err),
});
}
}
hs_i = end_cert;
const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
hs_i += 2;
hs_i += total_ext_size;
const issuer_bytes = handshake[issuer.start..issuer.end];
const ca_cert = ca_bundle.find(issuer_bytes);
std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions. ca_found={any}\n", .{
cert_size, total_ext_size, ca_cert != null,
});
}
},
@enumToInt(HandshakeType.certificate_verify) => {