diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 17c2e4afe9..6387eb48ae 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -177,8 +177,8 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); -pub const Der = @import("crypto/Der.zig"); -pub const CertificateBundle = @import("crypto/CertificateBundle.zig"); +pub const der = @import("crypto/der.zig"); +pub const Certificate = @import("crypto/Certificate.zig"); test { _ = aead.aegis.Aegis128L; @@ -269,8 +269,8 @@ test { _ = random; _ = errors; _ = tls; - _ = Der; - _ = CertificateBundle; + _ = der; + _ = Certificate; } test "CSPRNG" { diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig new file mode 100644 index 0000000000..3d50e43839 --- /dev/null +++ b/lib/std/crypto/Certificate.zig @@ -0,0 +1,446 @@ +buffer: []const u8, +index: u32, + +pub const Bundle = @import("Certificate/Bundle.zig"); + +pub const Algorithm = enum { + sha1WithRSAEncryption, + sha224WithRSAEncryption, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + + pub const map = std.ComptimeStringMap(Algorithm, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + }); + + pub fn Hash(comptime algorithm: Algorithm) type { + return switch (algorithm) { + .sha1WithRSAEncryption => crypto.hash.Sha1, + .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, + .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, + .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, + .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, + }; + } +}; + +pub const AlgorithmCategory = enum { + rsaEncryption, + X9_62_id_ecPublicKey, + + pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + }); +}; + +pub const Attribute = enum { + commonName, + serialNumber, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + organizationIdentifier, + + pub const map = std.ComptimeStringMap(Attribute, .{ + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + }); +}; + +pub const Parsed = struct { + certificate: Certificate, + issuer_slice: Slice, + subject_slice: Slice, + common_name_slice: Slice, + signature_slice: Slice, + signature_algorithm: Algorithm, + pub_key_algo: AlgorithmCategory, + pub_key_slice: Slice, + message_slice: Slice, + + pub const Slice = der.Element.Slice; + + pub fn slice(p: Parsed, s: Slice) []const u8 { + return p.certificate.buffer[s.start..s.end]; + } + + pub fn issuer(p: Parsed) []const u8 { + return p.slice(p.issuer_slice); + } + + pub fn subject(p: Parsed) []const u8 { + return p.slice(p.subject_slice); + } + + pub fn commonName(p: Parsed) []const u8 { + return p.slice(p.common_name_slice); + } + + pub fn signature(p: Parsed) []const u8 { + return p.slice(p.signature_slice); + } + + pub fn pubKey(p: Parsed) []const u8 { + return p.slice(p.pub_key_slice); + } + + pub fn message(p: Parsed) []const u8 { + return p.slice(p.message_slice); + } + + pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) !void { + // Check that the subject's issuer name matches the issuer's + // subject name. + if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) { + return error.CertificateIssuerMismatch; + } + + // TODO check the time validity for the subject + // TODO check the time validity for the issuer + + switch (parsed_subject.signature_algorithm) { + inline .sha1WithRSAEncryption, + .sha224WithRSAEncryption, + .sha256WithRSAEncryption, + .sha384WithRSAEncryption, + .sha512WithRSAEncryption, + => |algorithm| return verifyRsa( + algorithm.Hash(), + parsed_subject.message(), + parsed_subject.signature(), + parsed_issuer.pub_key_algo, + parsed_issuer.pubKey(), + ), + } + } +}; + +pub fn parse(cert: Certificate) !Parsed { + const cert_bytes = cert.buffer; + const certificate = try der.parseElement(cert_bytes, cert.index); + const tbs_certificate = try der.parseElement(cert_bytes, certificate.slice.start); + const version = try der.parseElement(cert_bytes, tbs_certificate.slice.start); + try checkVersion(cert_bytes, version); + const serial_number = try der.parseElement(cert_bytes, version.slice.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." + const tbs_signature = try der.parseElement(cert_bytes, serial_number.slice.end); + const issuer = try der.parseElement(cert_bytes, tbs_signature.slice.end); + const validity = try der.parseElement(cert_bytes, issuer.slice.end); + const subject = try der.parseElement(cert_bytes, validity.slice.end); + + const pub_key_info = try der.parseElement(cert_bytes, subject.slice.end); + const pub_key_signature_algorithm = try der.parseElement(cert_bytes, pub_key_info.slice.start); + const pub_key_algo_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.start); + const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); + const pub_key_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.end); + const pub_key = try parseBitString(cert, pub_key_elem); + + const rdn = try der.parseElement(cert_bytes, subject.slice.start); + const atav = try der.parseElement(cert_bytes, rdn.slice.start); + + var common_name = der.Element.Slice.empty; + var atav_i = atav.slice.start; + while (atav_i < atav.slice.end) { + const ty_elem = try der.parseElement(cert_bytes, atav_i); + const ty = try parseAttribute(cert_bytes, ty_elem); + const val = try der.parseElement(cert_bytes, ty_elem.slice.end); + switch (ty) { + .commonName => common_name = val.slice, + else => {}, + } + atav_i = val.slice.end; + } + + const sig_algo = try der.parseElement(cert_bytes, tbs_certificate.slice.end); + const algo_elem = try der.parseElement(cert_bytes, sig_algo.slice.start); + const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); + const sig_elem = try der.parseElement(cert_bytes, sig_algo.slice.end); + const signature = try parseBitString(cert, sig_elem); + + return .{ + .certificate = cert, + .common_name_slice = common_name, + .issuer_slice = issuer.slice, + .subject_slice = subject.slice, + .signature_slice = signature, + .signature_algorithm = signature_algorithm, + .message_slice = .{ .start = certificate.slice.start, .end = tbs_certificate.slice.end }, + .pub_key_algo = pub_key_algo, + .pub_key_slice = pub_key, + }; +} + +pub fn verify(subject: Certificate, issuer: Certificate) !void { + const parsed_subject = try subject.parse(); + const parsed_issuer = try issuer.parse(); + return parsed_subject.verify(parsed_issuer); +} + +pub fn contents(cert: Certificate, elem: der.Element) []const u8 { + return cert.buffer[elem.start..elem.end]; +} + +pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { + if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; + if (cert.buffer[elem.slice.start] != 0) return error.CertificateHasInvalidBitString; + return .{ .start = elem.slice.start + 1, .end = elem.slice.end }; +} + +pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Algorithm.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; +} + +pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return AlgorithmCategory.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithmCategory; +} + +pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Attribute.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; +} + +fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { + if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; + const pub_key_seq = try der.parseElement(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.parseElement(pub_key, pub_key_seq.slice.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try der.parseElement(pub_key, modulus_elem.slice.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; + if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const hash_der = switch (Hash) { + crypto.hash.Sha1 => [_]u8{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => [_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => [_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => [_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => [_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + + var msg_hashed: [Hash.digest_length]u8 = undefined; + Hash.hash(message, &msg_hashed, .{}); + + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; + const em: [modulus_len]u8 = + [2]u8{ 0, 1 } ++ + ([1]u8{0xff} ** ps_len) ++ + [1]u8{0} ++ + hash_der ++ + msg_hashed; + + const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); + + if (!mem.eql(u8, &em, &em_dec)) { + try std.testing.expectEqualSlices(u8, &em, &em_dec); + return error.CertificateSignatureInvalid; + } + }, + else => { + return error.CertificateSignatureUnsupportedBitCount; + }, + } +} + +pub fn checkVersion(bytes: []const u8, version: der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + +const std = @import("../std.zig"); +const crypto = std.crypto; +const mem = std.mem; +const der = std.crypto.der; +const Certificate = @This(); + +/// TODO: replace this with Frank's upcoming RSA implementation. the verify +/// function won't have the possibility of failure - it will either identify a +/// valid signature or an invalid signature. +/// This code is borrowed from https://github.com/shiguredo/tls13-zig +/// which is licensed under the Apache License Version 2.0, January 2004 +/// http://www.apache.org/licenses/ +/// The code has been modified. +const rsa = struct { + const BigInt = std.math.big.int.Managed; + + const PublicKey = struct { + n: BigInt, + e: BigInt, + + pub fn deinit(self: *PublicKey) void { + self.n.deinit(); + self.e.deinit(); + } + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { + var _n = try BigInt.init(allocator); + errdefer _n.deinit(); + try setBytes(&_n, modulus_bytes, allocator); + + var _e = try BigInt.init(allocator); + errdefer _e.deinit(); + try setBytes(&_e, pub_bytes, allocator); + + return .{ + .n = _n, + .e = _e, + }; + } + }; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { + var m = try BigInt.init(allocator); + defer m.deinit(); + + try setBytes(&m, &msg, allocator); + + if (m.order(public_key.n) != .lt) { + return error.MessageTooLong; + } + + var e = try BigInt.init(allocator); + defer e.deinit(); + + try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); + + var res: [modulus_len]u8 = undefined; + + try toBytes(&res, &e, allocator); + + return res; + } + + fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { + try r.set(0); + var tmp = try BigInt.init(allcator); + defer tmp.deinit(); + for (bytes) |b| { + try r.shiftLeft(r, 8); + try tmp.set(b); + try r.add(r, &tmp); + } + } + + fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var bin_raw: [512]u8 = undefined; + try toBytes(&bin_raw, x, allocator); + + var i: usize = 0; + while (bin_raw[i] == 0x00) : (i += 1) {} + const bin = bin_raw[i..]; + + try r.set(1); + var r1 = try BigInt.init(allocator); + defer r1.deinit(); + try BigInt.copy(&r1, a.toConst()); + i = 0; + while (i < bin.len * 8) : (i += 1) { + if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { + try BigInt.mul(&r1, r, &r1); + try mod(&r1, &r1, n, allocator); + try BigInt.sqr(r, r); + try mod(r, r, n, allocator); + } else { + try BigInt.mul(r, r, &r1); + try mod(r, r, n, allocator); + try BigInt.sqr(&r1, &r1); + try mod(&r1, &r1, n, allocator); + } + } + } + + fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { + const Error = error{ + BufferTooSmall, + }; + + var mask = try BigInt.initSet(allocator, 0xFF); + defer mask.deinit(); + var tmp = try BigInt.init(allocator); + defer tmp.deinit(); + + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a.toConst()); + + // Encoding into big-endian bytes + var i: usize = 0; + while (i < out.len) : (i += 1) { + try tmp.bitAnd(&a_copy, &mask); + const b = try tmp.to(u8); + out[out.len - i - 1] = b; + try a_copy.shiftRight(&a_copy, 8); + } + + if (!a_copy.eqZero()) { + return Error.BufferTooSmall; + } + } + + fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var q = try BigInt.init(allocator); + defer q.deinit(); + + try BigInt.divFloor(&q, rem, a, n); + } + + // TODO: flush the toilet + const poop = std.heap.page_allocator; +}; diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig new file mode 100644 index 0000000000..c2c18552a7 --- /dev/null +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -0,0 +1,174 @@ +//! A set of certificates. Typically pre-installed on every operating system, +//! these are "Certificate Authorities" used to validate SSL certificates. +//! This data structure stores certificates in DER-encoded form, all of them +//! concatenated together in the `bytes` array. The `map` field contains an +//! index from the DER-encoded subject name to the index of the containing +//! certificate within `bytes`. + +/// The key is the contents slice of the subject. +map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, +bytes: std.ArrayListUnmanaged(u8) = .{}, + +pub fn verify(cb: Bundle, subject: Certificate.Parsed) !void { + const bytes_index = cb.find(subject.issuer()) orelse return error.IssuerNotFound; + const issuer_cert: Certificate = .{ + .buffer = cb.bytes.items, + .index = bytes_index, + }; + const issuer = try issuer_cert.parse(); + try subject.verify(issuer); +} + +/// The returned bytes become invalid after calling any of the rescan functions +/// or add functions. +pub fn find(cb: Bundle, subject_name: []const u8) ?u32 { + const Adapter = struct { + cb: Bundle, + + pub fn hash(ctx: @This(), k: []const u8) u64 { + _ = ctx; + return std.hash_map.hashString(k); + } + + pub fn eql(ctx: @This(), a: []const u8, b_key: der.Element.Slice) bool { + const b = ctx.cb.bytes.items[b_key.start..b_key.end]; + return mem.eql(u8, a, b); + } + }; + return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); +} + +pub fn deinit(cb: *Bundle, gpa: Allocator) void { + cb.map.deinit(gpa); + cb.bytes.deinit(gpa); + cb.* = undefined; +} + +/// Empties the set of certificates and then scans the host operating system +/// file system standard locations for certificates. +pub fn rescan(cb: *Bundle, gpa: Allocator) !void { + switch (builtin.os.tag) { + .linux => return rescanLinux(cb, gpa), + else => @compileError("it is unknown where the root CA certificates live on this OS"), + } +} + +pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void { + var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { + error.FileNotFound => return, + else => |e| return e, + }; + defer dir.close(); + + cb.bytes.clearRetainingCapacity(); + cb.map.clearRetainingCapacity(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + switch (entry.kind) { + .File, .SymLink => {}, + else => continue, + } + + try addCertsFromFile(cb, gpa, dir.dir, entry.name); + } + + cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); +} + +pub fn addCertsFromFile( + cb: *Bundle, + gpa: Allocator, + dir: fs.Dir, + sub_file_path: []const u8, +) !void { + var file = try dir.openFile(sub_file_path, .{}); + defer file.close(); + + const size = try file.getEndPos(); + + // We borrow `bytes` as a temporary buffer for the base64-encoded data. + // This is possible by computing the decoded length and reserving the space + // for the decoded bytes first. + const decoded_size_upper_bound = size / 4 * 3; + try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); + const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; + const buffer = cb.bytes.allocatedSlice()[end_reserved..]; + const end_index = try file.readAll(buffer); + const encoded_bytes = buffer[0..end_index]; + + const begin_marker = "-----BEGIN CERTIFICATE-----"; + const end_marker = "-----END CERTIFICATE-----"; + + var start_index: usize = 0; + while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { + const cert_start = begin_marker_start + begin_marker.len; + const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse + return error.MissingEndCertificateMarker; + start_index = cert_end + end_marker.len; + const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); + const decoded_start = @intCast(u32, cb.bytes.items.len); + const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; + cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); + const k = try cb.key(decoded_start); + const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); + if (gop.found_existing) { + cb.bytes.items.len = decoded_start; + } else { + gop.value_ptr.* = decoded_start; + } + } +} + +pub fn key(cb: Bundle, bytes_index: u32) !der.Element.Slice { + const bytes = cb.bytes.items; + const certificate = try der.parseElement(bytes, bytes_index); + const tbs_certificate = try der.parseElement(bytes, certificate.slice.start); + const version = try der.parseElement(bytes, tbs_certificate.slice.start); + try Certificate.checkVersion(bytes, version); + const serial_number = try der.parseElement(bytes, version.slice.end); + const signature = try der.parseElement(bytes, serial_number.slice.end); + const issuer = try der.parseElement(bytes, signature.slice.end); + const validity = try der.parseElement(bytes, issuer.slice.end); + const subject = try der.parseElement(bytes, validity.slice.end); + + return subject.slice; +} + +const builtin = @import("builtin"); +const std = @import("../../std.zig"); +const fs = std.fs; +const mem = std.mem; +const crypto = std.crypto; +const Allocator = std.mem.Allocator; +const der = std.crypto.der; +const Certificate = std.crypto.Certificate; +const Bundle = @This(); + +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); + +const MapContext = struct { + cb: *const Bundle, + + pub fn hash(ctx: MapContext, k: der.Element.Slice) u64 { + return std.hash_map.hashString(ctx.cb.bytes.items[k.start..k.end]); + } + + pub fn eql(ctx: MapContext, a: der.Element.Slice, b: der.Element.Slice) bool { + const bytes = ctx.cb.bytes.items; + return mem.eql( + u8, + bytes[a.start..a.end], + bytes[b.start..b.end], + ); + } +}; + +test "scan for OS-provided certificates" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var bundle: Bundle = .{}; + defer bundle.deinit(std.testing.allocator); + + try bundle.rescan(std.testing.allocator); +} diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig deleted file mode 100644 index 6f9e77a4d7..0000000000 --- a/lib/std/crypto/CertificateBundle.zig +++ /dev/null @@ -1,593 +0,0 @@ -//! A set of certificates. Typically pre-installed on every operating system, -//! these are "Certificate Authorities" used to validate SSL certificates. -//! This data structure stores certificates in DER-encoded form, all of them -//! concatenated together in the `bytes` array. The `map` field contains an -//! index from the DER-encoded subject name to the index of the containing -//! certificate within `bytes`. - -map: std.HashMapUnmanaged(Key, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, -bytes: std.ArrayListUnmanaged(u8) = .{}, - -pub const Key = struct { - subject_start: u32, - subject_end: u32, -}; - -pub fn verify(cb: CertificateBundle, subject: Certificate.Parsed) !void { - const bytes_index = cb.find(subject.issuer) orelse return error.IssuerNotFound; - const issuer_cert: Certificate = .{ - .buffer = cb.bytes.items, - .index = bytes_index, - }; - const issuer = try issuer_cert.parse(); - try subject.verify(issuer); -} - -/// The returned bytes become invalid after calling any of the rescan functions -/// or add functions. -pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 { - const Adapter = struct { - cb: CertificateBundle, - - pub fn hash(ctx: @This(), k: []const u8) u64 { - _ = ctx; - return std.hash_map.hashString(k); - } - - pub fn eql(ctx: @This(), a: []const u8, b_key: Key) bool { - const b = ctx.cb.bytes.items[b_key.subject_start..b_key.subject_end]; - return mem.eql(u8, a, b); - } - }; - return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); -} - -pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void { - cb.map.deinit(gpa); - cb.bytes.deinit(gpa); - cb.* = undefined; -} - -/// Empties the set of certificates and then scans the host operating system -/// file system standard locations for certificates. -pub fn rescan(cb: *CertificateBundle, gpa: Allocator) !void { - switch (builtin.os.tag) { - .linux => return rescanLinux(cb, gpa), - else => @compileError("it is unknown where the root CA certificates live on this OS"), - } -} - -pub fn rescanLinux(cb: *CertificateBundle, gpa: Allocator) !void { - var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { - error.FileNotFound => return, - else => |e| return e, - }; - defer dir.close(); - - cb.bytes.clearRetainingCapacity(); - cb.map.clearRetainingCapacity(); - - var it = dir.iterate(); - while (try it.next()) |entry| { - switch (entry.kind) { - .File, .SymLink => {}, - else => continue, - } - - try addCertsFromFile(cb, gpa, dir.dir, entry.name); - } - - cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); -} - -pub fn addCertsFromFile( - cb: *CertificateBundle, - gpa: Allocator, - dir: fs.Dir, - sub_file_path: []const u8, -) !void { - var file = try dir.openFile(sub_file_path, .{}); - defer file.close(); - - const size = try file.getEndPos(); - - // We borrow `bytes` as a temporary buffer for the base64-encoded data. - // This is possible by computing the decoded length and reserving the space - // for the decoded bytes first. - const decoded_size_upper_bound = size / 4 * 3; - try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); - const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; - const buffer = cb.bytes.allocatedSlice()[end_reserved..]; - const end_index = try file.readAll(buffer); - const encoded_bytes = buffer[0..end_index]; - - const begin_marker = "-----BEGIN CERTIFICATE-----"; - const end_marker = "-----END CERTIFICATE-----"; - - var start_index: usize = 0; - while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { - const cert_start = begin_marker_start + begin_marker.len; - const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse - return error.MissingEndCertificateMarker; - start_index = cert_end + end_marker.len; - const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); - const decoded_start = @intCast(u32, cb.bytes.items.len); - const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; - cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); - const k = try cb.key(decoded_start); - const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); - if (gop.found_existing) { - cb.bytes.items.len = decoded_start; - } else { - gop.value_ptr.* = decoded_start; - } - } -} - -pub fn key(cb: CertificateBundle, bytes_index: u32) !Key { - const bytes = cb.bytes.items; - const certificate = try Der.parseElement(bytes, bytes_index); - const tbs_certificate = try Der.parseElement(bytes, certificate.start); - const version = try Der.parseElement(bytes, tbs_certificate.start); - try checkVersion(bytes, version); - const serial_number = try Der.parseElement(bytes, version.end); - const signature = try Der.parseElement(bytes, serial_number.end); - const issuer = try Der.parseElement(bytes, signature.end); - const validity = try Der.parseElement(bytes, issuer.end); - const subject = try Der.parseElement(bytes, validity.end); - - return .{ - .subject_start = subject.start, - .subject_end = subject.end, - }; -} - -pub const Certificate = struct { - buffer: []const u8, - index: u32, - - pub const Algorithm = enum { - sha1WithRSAEncryption, - sha224WithRSAEncryption, - sha256WithRSAEncryption, - sha384WithRSAEncryption, - sha512WithRSAEncryption, - - pub const map = std.ComptimeStringMap(Algorithm, .{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - }); - - pub fn Hash(comptime algorithm: Algorithm) type { - return switch (algorithm) { - .sha1WithRSAEncryption => crypto.hash.Sha1, - .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, - .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, - .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, - .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, - }; - } - }; - - pub const AlgorithmCategory = enum { - rsaEncryption, - X9_62_id_ecPublicKey, - - pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, - }); - }; - - pub const Attribute = enum { - commonName, - serialNumber, - countryName, - localityName, - stateOrProvinceName, - organizationName, - organizationalUnitName, - organizationIdentifier, - - pub const map = std.ComptimeStringMap(Attribute, .{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - }); - }; - - pub const Parsed = struct { - certificate: Certificate, - issuer: []const u8, - subject: []const u8, - common_name: []const u8, - signature: []const u8, - signature_algorithm: Algorithm, - message: []const u8, - pub_key_algo: AlgorithmCategory, - pub_key: []const u8, - - pub fn verify(subject: Parsed, issuer: Parsed) !void { - // Check that the subject's issuer name matches the issuer's - // subject name. - if (!mem.eql(u8, subject.issuer, issuer.subject)) { - return error.CertificateIssuerMismatch; - } - - // TODO check the time validity for the subject - // TODO check the time validity for the issuer - - switch (subject.signature_algorithm) { - inline .sha1WithRSAEncryption, - .sha224WithRSAEncryption, - .sha256WithRSAEncryption, - .sha384WithRSAEncryption, - .sha512WithRSAEncryption, - => |algorithm| return verifyRsa( - algorithm.Hash(), - subject.message, - subject.signature, - issuer.pub_key_algo, - issuer.pub_key, - ), - } - } - }; - - pub fn parse(cert: Certificate) !Parsed { - const cert_bytes = cert.buffer; - const certificate = try Der.parseElement(cert_bytes, cert.index); - const tbs_certificate = try Der.parseElement(cert_bytes, certificate.start); - const version = try Der.parseElement(cert_bytes, tbs_certificate.start); - try checkVersion(cert_bytes, version); - const serial_number = try Der.parseElement(cert_bytes, version.end); - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." - const tbs_signature = try Der.parseElement(cert_bytes, serial_number.end); - const issuer = try Der.parseElement(cert_bytes, tbs_signature.end); - const validity = try Der.parseElement(cert_bytes, issuer.end); - const subject = try Der.parseElement(cert_bytes, validity.end); - - const pub_key_info = try Der.parseElement(cert_bytes, subject.end); - const pub_key_signature_algorithm = try Der.parseElement(cert_bytes, pub_key_info.start); - const pub_key_algo_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.start); - const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); - const pub_key_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.end); - const pub_key = try parseBitString(cert, pub_key_elem); - - const rdn = try Der.parseElement(cert_bytes, subject.start); - const atav = try Der.parseElement(cert_bytes, rdn.start); - - var common_name: []const u8 = &.{}; - var atav_i = atav.start; - while (atav_i < atav.end) { - const ty_elem = try Der.parseElement(cert_bytes, atav_i); - const ty = try parseAttribute(cert_bytes, ty_elem); - const val = try Der.parseElement(cert_bytes, ty_elem.end); - switch (ty) { - .commonName => common_name = cert.contents(val), - else => {}, - } - atav_i = val.end; - } - - const sig_algo = try Der.parseElement(cert_bytes, tbs_certificate.end); - const algo_elem = try Der.parseElement(cert_bytes, sig_algo.start); - const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); - const sig_elem = try Der.parseElement(cert_bytes, sig_algo.end); - const signature = try parseBitString(cert, sig_elem); - - return .{ - .certificate = cert, - .common_name = common_name, - .issuer = cert.contents(issuer), - .subject = cert.contents(subject), - .signature = signature, - .signature_algorithm = signature_algorithm, - .message = cert_bytes[certificate.start..tbs_certificate.end], - .pub_key_algo = pub_key_algo, - .pub_key = pub_key, - }; - } - - pub fn verify(subject: Certificate, issuer: Certificate) !void { - const parsed_subject = try subject.parse(); - const parsed_issuer = try issuer.parse(); - return parsed_subject.verify(parsed_issuer); - } - - pub fn contents(cert: Certificate, elem: Der.Element) []const u8 { - return cert.buffer[elem.start..elem.end]; - } - - pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 { - if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; - if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString; - return cert.buffer[elem.start + 1 .. elem.end]; - } - - pub fn parseAlgorithm(bytes: []const u8, element: Der.Element) !Algorithm { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return Algorithm.map.get(bytes[element.start..element.end]) orelse - return error.CertificateHasUnrecognizedAlgorithm; - } - - pub fn parseAlgorithmCategory(bytes: []const u8, element: Der.Element) !AlgorithmCategory { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return AlgorithmCategory.map.get(bytes[element.start..element.end]) orelse { - std.debug.print("unrecognized algorithm category: {}\n", .{std.fmt.fmtSliceHexLower(bytes[element.start..element.end])}); - return error.CertificateHasUnrecognizedAlgorithmCategory; - }; - } - - pub fn parseAttribute(bytes: []const u8, element: Der.Element) !Attribute { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return Attribute.map.get(bytes[element.start..element.end]) orelse - return error.CertificateHasUnrecognizedAlgorithm; - } - - fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { - if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pub_key_seq = try Der.parseElement(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start); - if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end); - if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - // Skip over meaningless zeroes in the modulus. - const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end]; - const modulus_offset = for (modulus_raw) |byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - const modulus = modulus_raw[modulus_offset..]; - const exponent = pub_key[exponent_elem.start..exponent_elem.end]; - if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; - if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; - - const hash_der = switch (Hash) { - crypto.hash.Sha1 => [_]u8{ - 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, - 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, - }, - crypto.hash.sha2.Sha224 => [_]u8{ - 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, - 0x00, 0x04, 0x1c, - }, - crypto.hash.sha2.Sha256 => [_]u8{ - 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, - 0x00, 0x04, 0x20, - }, - crypto.hash.sha2.Sha384 => [_]u8{ - 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, - 0x00, 0x04, 0x30, - }, - crypto.hash.sha2.Sha512 => [_]u8{ - 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, - 0x00, 0x04, 0x40, - }, - else => @compileError("unreachable"), - }; - - var msg_hashed: [Hash.digest_length]u8 = undefined; - Hash.hash(message, &msg_hashed, .{}); - - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = - [2]u8{ 0, 1 } ++ - ([1]u8{0xff} ** ps_len) ++ - [1]u8{0} ++ - hash_der ++ - msg_hashed; - - const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); - const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); - - if (!mem.eql(u8, &em, &em_dec)) { - try std.testing.expectEqualSlices(u8, &em, &em_dec); - return error.CertificateSignatureInvalid; - } - }, - else => { - return error.CertificateSignatureUnsupportedBitCount; - }, - } - } -}; - -fn checkVersion(bytes: []const u8, version: Der.Element) !void { - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; - } -} - -const builtin = @import("builtin"); -const std = @import("../std.zig"); -const fs = std.fs; -const mem = std.mem; -const crypto = std.crypto; -const Allocator = std.mem.Allocator; -const Der = std.crypto.Der; -const CertificateBundle = @This(); - -const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); - -const MapContext = struct { - cb: *const CertificateBundle, - - pub fn hash(ctx: MapContext, k: Key) u64 { - return std.hash_map.hashString(ctx.cb.bytes.items[k.subject_start..k.subject_end]); - } - - pub fn eql(ctx: MapContext, a: Key, b: Key) bool { - const bytes = ctx.cb.bytes.items; - return mem.eql( - u8, - bytes[a.subject_start..a.subject_end], - bytes[b.subject_start..b.subject_end], - ); - } -}; - -test "scan for OS-provided certificates" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var bundle: CertificateBundle = .{}; - defer bundle.deinit(std.testing.allocator); - - try bundle.rescan(std.testing.allocator); -} - -/// TODO: replace this with Frank's upcoming RSA implementation. the verify -/// function won't have the possibility of failure - it will either identify a -/// valid signature or an invalid signature. -/// This code is borrowed from https://github.com/shiguredo/tls13-zig -/// which is licensed under the Apache License Version 2.0, January 2004 -/// http://www.apache.org/licenses/ -/// The code has been modified. -const rsa = struct { - const BigInt = std.math.big.int.Managed; - - const PublicKey = struct { - n: BigInt, - e: BigInt, - - pub fn deinit(self: *PublicKey) void { - self.n.deinit(); - self.e.deinit(); - } - - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { - var _n = try BigInt.init(allocator); - errdefer _n.deinit(); - try setBytes(&_n, modulus_bytes, allocator); - - var _e = try BigInt.init(allocator); - errdefer _e.deinit(); - try setBytes(&_e, pub_bytes, allocator); - - return .{ - .n = _n, - .e = _e, - }; - } - }; - - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { - var m = try BigInt.init(allocator); - defer m.deinit(); - - try setBytes(&m, &msg, allocator); - - if (m.order(public_key.n) != .lt) { - return error.MessageTooLong; - } - - var e = try BigInt.init(allocator); - defer e.deinit(); - - try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); - - var res: [modulus_len]u8 = undefined; - - try toBytes(&res, &e, allocator); - - return res; - } - - fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { - try r.set(0); - var tmp = try BigInt.init(allcator); - defer tmp.deinit(); - for (bytes) |b| { - try r.shiftLeft(r, 8); - try tmp.set(b); - try r.add(r, &tmp); - } - } - - fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { - var bin_raw: [512]u8 = undefined; - try toBytes(&bin_raw, x, allocator); - - var i: usize = 0; - while (bin_raw[i] == 0x00) : (i += 1) {} - const bin = bin_raw[i..]; - - try r.set(1); - var r1 = try BigInt.init(allocator); - defer r1.deinit(); - try BigInt.copy(&r1, a.toConst()); - i = 0; - while (i < bin.len * 8) : (i += 1) { - if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { - try BigInt.mul(&r1, r, &r1); - try mod(&r1, &r1, n, allocator); - try BigInt.sqr(r, r); - try mod(r, r, n, allocator); - } else { - try BigInt.mul(r, r, &r1); - try mod(r, r, n, allocator); - try BigInt.sqr(&r1, &r1); - try mod(&r1, &r1, n, allocator); - } - } - } - - fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { - const Error = error{ - BufferTooSmall, - }; - - var mask = try BigInt.initSet(allocator, 0xFF); - defer mask.deinit(); - var tmp = try BigInt.init(allocator); - defer tmp.deinit(); - - var a_copy = try BigInt.init(allocator); - defer a_copy.deinit(); - try a_copy.copy(a.toConst()); - - // Encoding into big-endian bytes - var i: usize = 0; - while (i < out.len) : (i += 1) { - try tmp.bitAnd(&a_copy, &mask); - const b = try tmp.to(u8); - out[out.len - i - 1] = b; - try a_copy.shiftRight(&a_copy, 8); - } - - if (!a_copy.eqZero()) { - return Error.BufferTooSmall; - } - } - - fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { - var q = try BigInt.init(allocator); - defer q.deinit(); - - try BigInt.divFloor(&q, rem, a, n); - } - - // TODO: flush the toilet - const poop = std.heap.page_allocator; -}; diff --git a/lib/std/crypto/Der.zig b/lib/std/crypto/der.zig similarity index 92% rename from lib/std/crypto/Der.zig rename to lib/std/crypto/der.zig index 7b183d5c34..82f75421ea 100644 --- a/lib/std/crypto/Der.zig +++ b/lib/std/crypto/der.zig @@ -99,8 +99,14 @@ pub const Oid = enum { pub const Element = struct { identifier: Identifier, - start: u32, - end: u32, + slice: Slice, + + pub const Slice = struct { + start: u32, + end: u32, + + pub const empty: Slice = .{ .start = 0, .end = 0 }; + }; }; pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; @@ -114,8 +120,10 @@ pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { if ((size_byte >> 7) == 0) { return .{ .identifier = identifier, - .start = i, - .end = i + size_byte, + .slice = .{ + .start = i, + .end = i + size_byte, + }, }; } @@ -132,8 +140,10 @@ pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { return .{ .identifier = identifier, - .start = i, - .end = i + long_form_size, + .slice = .{ + .start = i, + .end = i + long_form_size, + }, }; } @@ -145,9 +155,9 @@ pub const ParseObjectIdError = error{ pub fn parseObjectId(bytes: []const u8, element: Element) ParseObjectIdError!Oid { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; - return Oid.map.get(bytes[element.start..element.end]) orelse + return Oid.map.get(bytes[element.slice.start..element.slice.end]) orelse return error.CertificateHasUnrecognizedObjectId; } const std = @import("../std.zig"); -const Der = @This(); +const der = @This(); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index eb1b1b80bc..c8fd41f83a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,6 +1,5 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -const Der = std.crypto.Der; const Client = @This(); const net = std.net; const mem = std.mem; @@ -18,7 +17,7 @@ const int2 = tls.int2; const int3 = tls.int3; const array = tls.array; const enum_array = tls.enum_array; -const Certificate = crypto.CertificateBundle.Certificate; +const Certificate = crypto.Certificate; application_cipher: ApplicationCipher, read_seq: u64, @@ -30,7 +29,7 @@ partially_read_len: u15, eof: bool, /// `host` is only borrowed during this function call. -pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []const u8) !Client { +pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client { const host_len = @intCast(u16, host.len); var random_buffer: [128]u8 = undefined; @@ -298,9 +297,19 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con break :i end; }; + // This is used for two purposes: + // * Detect whether a certificate is the first one presented, in which case + // we need to verify the host name. + // * Flip back and forth between the two cleartext buffers in order to keep + // the previous certificate in memory so that it can be verified by the + // next one. + var cert_index: usize = 0; var read_seq: u64 = 0; - var validated_cert = false; - var is_subsequent_cert = false; + var prev_cert: Certificate.Parsed = undefined; + // Set to true once a trust chain has been established from the first + // certificate to a root CA. + var cert_verification_done = false; + var cleartext_bufs: [2][8000]u8 = undefined; while (true) { const end_hdr = i + 5; @@ -328,7 +337,8 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - var cleartext_buf: [8000]u8 = undefined; + const cleartext_buf = &cleartext_bufs[cert_index % 2]; + const cleartext = switch (handshake_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); @@ -393,7 +403,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } - if (validated_cert) break :cert; + if (cert_verification_done) break :cert; var hs_i: u32 = 0; const cert_req_ctx_len = handshake[hs_i]; hs_i += 1; @@ -411,12 +421,22 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con .index = hs_i, }; const subject = try subject_cert.parse(); - if (!is_subsequent_cert) { - is_subsequent_cert = true; - if (mem.eql(u8, subject.common_name, host)) { + if (cert_index > 0) { + if (prev_cert.verify(subject)) |_| { + std.debug.print("previous certificate verified\n", .{}); + } else |err| { + std.debug.print("unable to validate previous cert: {s}\n", .{ + @errorName(err), + }); + } + } else { + // Verify the host on the first certificate. + const common_name = subject.commonName(); + if (mem.eql(u8, common_name, host)) { std.debug.print("exact host match\n", .{}); - } else if (mem.startsWith(u8, subject.common_name, "*.") and - mem.eql(u8, subject.common_name[2..], host)) + } else if (mem.startsWith(u8, common_name, "*.") and + (mem.endsWith(u8, host, common_name[1..]) or + mem.eql(u8, common_name[2..], host))) { std.debug.print("wildcard host match\n", .{}); } else { @@ -427,17 +447,17 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con if (ca_bundle.verify(subject)) |_| { std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); - validated_cert = true; + cert_verification_done = true; break :cert; } else |err| { std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ @errorName(err), }); - // TODO handle a certificate - // signing chain that ends in a - // root-validated one. } + prev_cert = subject; + cert_index += 1; + hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 58686ed2e5..1d10870312 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -7,7 +7,7 @@ const Client = @This(); allocator: std.mem.Allocator, headers: std.ArrayListUnmanaged(u8) = .{}, active_requests: usize = 0, -ca_bundle: std.crypto.CertificateBundle = .{}, +ca_bundle: std.crypto.Certificate.Bundle = .{}, pub const Request = struct { client: *Client,