From 5d7eca6669228cec762fc9063a7ea3cb52af357c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 21:32:15 -0700 Subject: [PATCH] std.crypto.tls.Client: fix verify_data for batched handshakes --- lib/std/crypto/tls.zig | 7 ++- lib/std/crypto/tls/Client.zig | 92 +++++++++++++++++++---------------- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index d2a347e87e..09a22f9a23 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -221,6 +221,12 @@ pub const CipherSuite = enum(u16) { _, }; +pub const CertificateType = enum(u8) { + X509 = 0, + RawPublicKey = 2, + _, +}; + pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { return struct { pub const AEAD = AeadType; @@ -237,7 +243,6 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, transcript_hash: Hash, - finished_digest: [Hash.digest_length]u8, }; } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6d6e0754da..7fb96ff00c 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -62,12 +62,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, .ed25519, - .ed448, - .rsa_pss_pss_sha256, - .rsa_pss_pss_sha384, - .rsa_pss_pss_sha512, - .rsa_pkcs1_sha1, - .ecdsa_sha1, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ .secp256r1, .x25519, @@ -98,24 +92,21 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { int2(legacy_compression_methods) ++ extensions_header; - const handshake = + const out_handshake = [_]u8{@enumToInt(HandshakeType.client_hello)} ++ int3(@intCast(u24, client_hello.len + host_len)) ++ client_hello; - const hello_header = [_]u8{ - // Plaintext header + const plaintext_header = [_]u8{ @enumToInt(ContentType.handshake), 0x03, 0x01, // legacy_record_version - } ++ - int2(@intCast(u16, handshake.len + host_len)) ++ - handshake; + } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake; { var iovecs = [_]std.os.iovec_const{ .{ - .iov_base = &hello_header, - .iov_len = hello_header.len, + .iov_base = &plaintext_header, + .iov_len = plaintext_header.len, }, .{ .iov_base = host.ptr, @@ -125,7 +116,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { try stream.writevAll(&iovecs); } - const client_hello_bytes1 = hello_header[5..]; + const client_hello_bytes1 = plaintext_header[5..]; var cipher_params: CipherParams = undefined; @@ -176,7 +167,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); i += 2; const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int); - std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag}); const legacy_compression_method = frag[i]; i += 1; _ = legacy_compression_method; @@ -243,12 +233,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { if (!have_shared_key) return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { - @enumToInt(tls.ProtocolVersion.tls_1_2) => { - std.debug.print("server wants TLS v1.2\n", .{}); - }, - @enumToInt(tls.ProtocolVersion.tls_1_3) => { - std.debug.print("server wants TLS v1.3\n", .{}); - }, + @enumToInt(tls.ProtocolVersion.tls_1_3) => {}, else => return error.TlsIllegalParameter, } @@ -270,7 +255,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .client_handshake_iv = undefined, .server_handshake_iv = undefined, .transcript_hash = P.Hash.init(.{}), - .finished_digest = undefined, }); const p = &@field(cipher_params, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 @@ -361,7 +345,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const ad = handshake_buf[end_hdr - 5 ..][0..5]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; - p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]); break :c cleartext; }, }; @@ -378,17 +361,22 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const next_handshake_i = ct_i + handshake_len; if (next_handshake_i > cleartext.len - 1) return error.TlsBadLength; + const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; + const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { @enumToInt(HandshakeType.encrypted_extensions) => { - const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const end_ext_i = ct_i + total_ext_size; - while (ct_i < end_ext_i) { - const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const next_ext_i = ct_i + ext_size; + switch (cipher_params) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + const total_ext_size = mem.readIntBig(u16, handshake[0..2]); + var hs_i: usize = 2; + const end_ext_i = 2 + total_ext_size; + while (hs_i < end_ext_i) { + const et = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + const next_ext_i = hs_i + ext_size; switch (et) { @enumToInt(tls.ExtensionType.server_name) => {}, else => { @@ -397,19 +385,38 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }); }, } - ct_i = next_ext_i; + hs_i = next_ext_i; } }, @enumToInt(HandshakeType.certificate) => { - std.debug.print("cool certificate bro\n", .{}); + switch (cipher_params) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + var hs_i: usize = 0; + const cert_req_ctx_len = handshake[hs_i]; + hs_i += 1; + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); + hs_i += 3; + const end_certs = hs_i + certs_size; + while (hs_i < end_certs) { + const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); + hs_i += 3; + hs_i += cert_size; + const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + hs_i += total_ext_size; + + std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions\n", .{ + cert_size, total_ext_size, + }); + } }, @enumToInt(HandshakeType.certificate_verify) => { - std.debug.print("the certificate came with a fancy signature\n", .{}); switch (cipher_params) { - inline else => |*p| { - p.finished_digest = p.transcript_hash.peek(); - }, + inline else => |*p| p.transcript_hash.update(wrapped_handshake), } + std.debug.print("ignoring certificate_verify\n", .{}); }, @enumToInt(HandshakeType.finished) => { // This message is to trick buggy proxies into behaving correctly. @@ -422,9 +429,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const app_cipher = switch (cipher_params) { inline else => |*p, tag| c: { const P = @TypeOf(p.*); - const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key); - const actual_server_verify_data = cleartext[ct_i..][0..handshake_len]; - if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data)) + const finished_digest = p.transcript_hash.peek(); + p.transcript_hash.update(wrapped_handshake); + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, handshake)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);