diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6d9a75dc22..2a0d49ca69 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -214,158 +214,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In try stream.writevAll(&iovecs); } - const client_hello_bytes1 = cleartext_header[tls.record_header_len..]; - var tls_version: tls.ProtocolVersion = undefined; - var cipher_suite_tag: tls.CipherSuite = undefined; - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1); - const legacy_version = hsd.decode(u16); - @memcpy(&server_hello_rand, hsd.array(32)); - if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { - // This is a HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - try hsd.ensure(legacy_session_id_echo_len + 2 + 1); - const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); - cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - var supported_version: ?u16 = null; - if (!hsd.eof()) { - try hsd.ensure(2); - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version) |_| return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - try key_share.exchange(named_group, extd.slice(key_size)); - }, - else => {}, - } - } - } - - tls_version = @enumFromInt(supported_version orelse legacy_version); - switch (tls_version) { - .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, - .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and - server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, - else => return error.TlsIllegalParameter, - } - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => |tag| { - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ - .transcript_hash = .init(.{}), - .version = undefined, - }); - const p = &@field(handshake_cipher, @tagName(tag.with())); - p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - }, - - else => return error.TlsIllegalParameter, - } - switch (tls_version) { - .tls_1_3 => switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; - const p = &@field(handshake_cipher, @tagName(tag.with())); - const P = @TypeOf(p.*).A; - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - p.version = .{ .tls_1_3 = undefined }; - const pv = &p.version.tls_1_3; - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => return error.TlsIllegalParameter, - }, - .tls_1_2 => switch (cipher_suite_tag) { - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => {}, - else => return error.TlsIllegalParameter, - }, - else => return error.TlsIllegalParameter, - } - }, - else => return error.TlsUnexpectedMessage, - } - } - // This is used for two purposes: // * Detect whether a certificate is the first one presented, in which case // we need to verify the host name. @@ -384,13 +233,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In /// Application cipher is in use application, }; - var pending_cipher_state: CipherState = switch (tls_version) { - .tls_1_3 => .handshake, - .tls_1_2 => .cleartext, - else => unreachable, - }; - var cipher_state: CipherState = .cleartext; + var pending_cipher_state: CipherState = .cleartext; + var cipher_state = pending_cipher_state; const HandshakeState = enum { + /// In this state we expect only a server hello message. + hello, /// In this state we expect only an encrypted_extensions message. encrypted_extensions, /// In this state we expect certificate handshake messages. @@ -404,15 +251,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In /// In this state, we expect only the finished handshake message. finished, }; - var handshake_state: HandshakeState = switch (tls_version) { - .tls_1_3 => .encrypted_extensions, - .tls_1_2 => .certificate, - else => unreachable, - }; - var cleartext_bufs: [2][8000]u8 = undefined; + var handshake_state: HandshakeState = .hello; + var handshake_cipher: tls.HandshakeCipher = undefined; var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); + var cleartext_bufs: [2][8000]u8 = undefined; + var handshake_buffer: [8000]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); const record_header = d.buf[d.idx..][0..tls.record_header_len]; @@ -526,11 +372,132 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In var hsd = try ctd.sub(handshake_len); const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; switch (handshake_type) { + .server_hello => { + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .hello) return error.TlsUnexpectedMessage; + try hsd.ensure(2 + 32 + 1); + const legacy_version = hsd.decode(u16); + @memcpy(&server_hello_rand, hsd.array(32)); + if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. + return error.TlsUnexpectedMessage; + } + const legacy_session_id_echo_len = hsd.decode(u8); + try hsd.ensure(legacy_session_id_echo_len + 2 + 1); + const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + var supported_version: ?u16 = null; + if (!hsd.eof()) { + try hsd.ensure(2); + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + switch (et) { + .supported_versions => { + if (supported_version) |_| return error.TlsIllegalParameter; + try extd.ensure(2); + supported_version = extd.decode(u16); + }, + .key_share => { + if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); + try key_share.exchange(named_group, extd.slice(key_size)); + }, + else => {}, + } + } + } + + tls_version = @enumFromInt(supported_version orelse legacy_version); + switch (tls_version) { + .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, + .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and + server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, + else => return error.TlsIllegalParameter, + } + + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |tag| { + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ + .transcript_hash = .init(.{}), + .version = undefined, + }); + const p = &@field(handshake_cipher, @tagName(tag.with())); + p.transcript_hash.update(cleartext_header[tls.record_header_len..]); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(wrapped_handshake); + }, + + else => return error.TlsIllegalParameter, + } + switch (tls_version) { + .tls_1_3 => { + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + => |tag| { + const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; + const p = &@field(handshake_cipher, @tagName(tag.with())); + const P = @TypeOf(p.*).A; + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(P.Hash); + p.version = .{ .tls_1_3 = undefined }; + const pv = &p.version.tls_1_3; + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + else => return error.TlsIllegalParameter, + } + pending_cipher_state = .handshake; + handshake_state = .encrypted_extensions; + }, + .tls_1_2 => switch (cipher_suite_tag) { + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => handshake_state = .certificate, + else => return error.TlsIllegalParameter, + }, + else => return error.TlsIllegalParameter, + } + }, .encrypted_extensions => { if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; if (cipher_state != .handshake) return error.TlsUnexpectedMessage; if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } @@ -548,16 +515,18 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In else => {}, } } + handshake_state = .certificate; }, .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } + if (cipher_state == .application) return error.TlsUnexpectedMessage; switch (handshake_state) { .certificate => {}, .trust_chain_established => break :cert, else => return error.TlsUnexpectedMessage, } + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } switch (tls_version) { .tls_1_3 => { @@ -614,7 +583,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; switch (handshake_state) { - .trust_chain_established => handshake_state = .server_hello_done, + .trust_chain_established => {}, .certificate => return error.TlsCertificateNotVerified, else => return error.TlsUnexpectedMessage, } @@ -631,12 +600,12 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const server_pub_key = hsd.slice(key_size); try main_cert_pub_key.verifySignature(&hsd, &.{ &client_hello_rand, &server_hello_rand, hsd.buf[0..hsd.idx] }); try key_share.exchange(named_group, server_pub_key); + handshake_state = .server_hello_done; }, .server_hello_done => { if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; - handshake_state = .finished; const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -680,7 +649,6 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), } }; const pv = &p.version.tls_1_2; - pending_cipher_state = .application; const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) nonce: { @@ -715,12 +683,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In }, } write_seq += 1; + pending_cipher_state = .application; + handshake_state = .finished; }, .certificate_verify => { if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; if (cipher_state != .handshake) return error.TlsUnexpectedMessage; switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, + .trust_chain_established => {}, .certificate => return error.TlsCertificateNotVerified, else => return error.TlsUnexpectedMessage, } @@ -733,6 +703,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In p.transcript_hash.update(wrapped_handshake); }, } + handshake_state = .finished; }, .finished => { if (cipher_state == .cleartext) return error.TlsUnexpectedMessage;