From 462b3ed69c20ea5dcae1660761012b3d5fa91367 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 13:16:53 -0700 Subject: [PATCH] std.crypto.Tls: handshake fixes * Handle multiple handshakes in one encrypted record * Fix incorrect handshake length sent to server --- lib/std/crypto/Tls.zig | 203 ++++++++++++++++++++++------------------- 1 file changed, 109 insertions(+), 94 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 19bd3442cf..4ea64f1be9 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -17,6 +17,10 @@ eof: bool, pub const ciphertext_record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; +pub const hello_retry_request_sequence = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +}; pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, @@ -450,7 +454,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const hello = frag[4..]; const legacy_version = mem.readIntBig(u16, hello[0..2]); const random = hello[2..34].*; - _ = random; + if (mem.eql(u8, &random, &hello_retry_request_sequence)) { + @panic("TODO handle HelloRetryRequest"); + } const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); @@ -551,7 +557,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ + //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\nclient_handshake_iv: {}\nserver_handshake_iv: {}\n", .{ // std.fmt.fmtSliceHexLower(&shared_key), // std.fmt.fmtSliceHexLower(&hello_hash), // std.fmt.fmtSliceHexLower(&early_secret), @@ -560,6 +566,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { // std.fmt.fmtSliceHexLower(&p.handshake_secret), // std.fmt.fmtSliceHexLower(&client_secret), // std.fmt.fmtSliceHexLower(&server_secret), + // std.fmt.fmtSliceHexLower(&p.client_handshake_iv), + // std.fmt.fmtSliceHexLower(&p.server_handshake_iv), //}); }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ -643,106 +651,113 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { }, }; - const inner_ct = cleartext[cleartext.len - 1]; - std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)}); + const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { - @enumToInt(ContentType.handshake) => { - const handshake_len = mem.readIntBig(u24, cleartext[1..4]); - if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength; - std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len }); - switch (cleartext[0]) { - @enumToInt(HandshakeType.encrypted_extensions) => { - const ext_size = mem.readIntBig(u16, cleartext[4..6]); - std.debug.print("{d} bytes of encrypted extensions\n", .{ - ext_size, - }); - }, - @enumToInt(HandshakeType.certificate) => { - std.debug.print("cool certificate bro\n", .{}); - }, - @enumToInt(HandshakeType.certificate_verify) => { - std.debug.print("the certificate came with a fancy signature\n", .{}); - }, - @enumToInt(HandshakeType.finished) => { - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @enumToInt(ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (cipher_params) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { - const P = @TypeOf(p.*); - // TODO verify the server's data - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @enumToInt(HandshakeType.finished), - 0, 0, verify_data.len + 1 + P.AEAD.tag_length, // length - } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type = cleartext[ct_i]; + ct_i += 1; + const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len - 1) + return error.TlsBadLength; + switch (handshake_type) { + @enumToInt(HandshakeType.encrypted_extensions) => { + const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); + ct_i += 2; + std.debug.print("{d} bytes of encrypted extensions\n", .{ + ext_size, + }); + }, + @enumToInt(HandshakeType.certificate) => { + std.debug.print("cool certificate bro\n", .{}); + }, + @enumToInt(HandshakeType.certificate_verify) => { + std.debug.print("the certificate came with a fancy signature\n", .{}); + }, + @enumToInt(HandshakeType.finished) => { + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { + const P = @TypeOf(p.*); + // TODO verify the server's data + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(HandshakeType.finished), + 0, 0, verify_data.len, // length + } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{ - @enumToInt(ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ ([1]u8{undefined} ** wrapped_len); + var finished_msg = [_]u8{ + @enumToInt(ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ ([1]u8{undefined} ** wrapped_len); - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - //const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - _ = client_change_cipher_spec_msg; - const both_msgs = finished_msg; - try stream.writeAll(&both_msgs); + const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + try stream.writeAll(&both_msgs); - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ - // std.fmt.fmtSliceHexLower(&p.master_secret), - // std.fmt.fmtSliceHexLower(&client_secret), - // std.fmt.fmtSliceHexLower(&server_secret), - //}); - break :c @unionInit(ApplicationCipher, @tagName(tag), .{ - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, - }; - std.debug.print("remaining bytes: {d}\n", .{len - end}); - return .{ - .application_cipher = app_cipher, - .read_seq = 0, - .write_seq = 0, - .partially_read_buffer = undefined, - .partially_read_len = 0, - .eof = false, - }; - }, - else => { - std.debug.print("handshake type: {d}\n", .{cleartext[0]}); - return error.TlsUnexpectedMessage; - }, + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ + // std.fmt.fmtSliceHexLower(&p.master_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); + break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + std.debug.print("remaining bytes: {d}\n", .{len - end}); + return .{ + .application_cipher = app_cipher, + .read_seq = 0, + .write_seq = 0, + .partially_read_buffer = undefined, + .partially_read_len = 0, + .eof = false, + }; + }, + else => { + std.debug.print("handshake type: {d}\n", .{cleartext[0]}); + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len - 1) break; } }, else => { - std.debug.print("inner content type: {d}\n", .{inner_ct}); + std.debug.print("inner content type: {any}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } @@ -803,7 +818,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { tls.write_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{ + //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ // tls.write_seq - 1, // std.fmt.fmtSliceHexLower(&nonce), // std.fmt.fmtSliceHexLower(&p.client_key),