From 40a85506b2e6a97af9c06bdcd001b6fd84cc549a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 15 Dec 2022 20:35:41 -0700 Subject: [PATCH] std.crypto.Tls: add read/write methods --- lib/std/crypto/Tls.zig | 560 ++++++++++++++++++++++++++++++---------- lib/std/crypto/sha2.zig | 22 ++ lib/std/http/Client.zig | 12 +- 3 files changed, 458 insertions(+), 136 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index ea1bff9a08..6b5374512b 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -5,24 +5,24 @@ const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; -state: State = .start, -x25519_priv_key: [32]u8 = undefined, -x25519_pub_key: [32]u8 = undefined, -x25519_server_pub_key: [32]u8 = undefined, +application_cipher: ApplicationCipher, +read_seq: u64, +write_seq: u64, +/// The size is enough to contain exactly one TLSCiphertext record. +partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8, +/// The number of partially read bytes inside `partiall_read_buffer`. +partially_read_len: u15, -const ProtocolVersion = enum(u16) { +pub const ciphertext_record_header_len = 5; +pub const max_ciphertext_len = (1 << 14) + 256; + +pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, }; -const State = enum { - /// In this state, all fields are undefined except state. - start, - sent_hello, -}; - -const ContentType = enum(u8) { +pub const ContentType = enum(u8) { invalid = 0, change_cipher_spec = 20, alert = 21, @@ -31,7 +31,7 @@ const ContentType = enum(u8) { _, }; -const HandshakeType = enum(u8) { +pub const HandshakeType = enum(u8) { client_hello = 1, server_hello = 2, new_session_ticket = 4, @@ -45,7 +45,7 @@ const HandshakeType = enum(u8) { message_hash = 254, }; -const ExtensionType = enum(u16) { +pub const ExtensionType = enum(u16) { /// RFC 6066 server_name = 0, /// RFC 6066 @@ -92,13 +92,13 @@ const ExtensionType = enum(u16) { key_share = 51, }; -const AlertLevel = enum(u8) { +pub const AlertLevel = enum(u8) { warning = 1, fatal = 2, _, }; -const AlertDescription = enum(u8) { +pub const AlertDescription = enum(u8) { close_notify = 0, unexpected_message = 10, bad_record_mac = 20, @@ -129,7 +129,7 @@ const AlertDescription = enum(u8) { _, }; -const SignatureScheme = enum(u16) { +pub const SignatureScheme = enum(u16) { // RSASSA-PKCS1-v1_5 algorithms rsa_pkcs1_sha256 = 0x0401, rsa_pkcs1_sha384 = 0x0501, @@ -161,7 +161,7 @@ const SignatureScheme = enum(u16) { _, }; -const NamedGroup = enum(u16) { +pub const NamedGroup = enum(u16) { // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, secp384r1 = 0x0018, @@ -211,7 +211,7 @@ const NamedGroup = enum(u16) { // * ExtensionType extension_type; // * opaque extension_data<0..2^16-1>; -const CipherSuite = enum(u16) { +pub const CipherSuite = enum(u16) { TLS_AES_128_GCM_SHA256 = 0x1301, TLS_AES_256_GCM_SHA384 = 0x1302, TLS_CHACHA20_POLY1305_SHA256 = 0x1303, @@ -219,6 +219,73 @@ const CipherSuite = enum(u16) { TLS_AES_128_CCM_8_SHA256 = 0x1305, }; +pub const CipherParams = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.key_len]u8, + master_secret: [Hkdf.key_len]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_AES_256_GCM_SHA384: struct { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.key_len]u8, + master_secret: [Hkdf.key_len]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + +/// Encryption parameters for application traffic. +pub const ApplicationCipher = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_AES_256_GCM_SHA384: struct { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + const cipher_suites = blk: { const fields = @typeInfo(CipherSuite).Enum.fields; var result: [(fields.len + 1) * 2]u8 = undefined; @@ -231,10 +298,11 @@ const cipher_suites = blk: { break :blk result; }; -pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { - assert(tls.state == .start); - crypto.random.bytes(&tls.x25519_priv_key); - tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| { +/// `host` is only borrowed during this function call. +pub fn init(stream: net.Stream, host: []const u8) !Tls { + var x25519_priv_key: [32]u8 = undefined; + crypto.random.bytes(&x25519_priv_key); + const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| { switch (err) { // Only possible to happen if the private key is all zeroes. error.IdentityElement => return error.InsufficientEntropy, @@ -293,7 +361,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 0, 32, // byte length of key_exchange - } ++ tls.x25519_pub_key ++ [_]u8{ + } ++ x25519_pub_key ++ [_]u8{ // Extension: server_name 0, 0, // ExtensionType.server_name @@ -330,25 +398,23 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len)); mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len)); - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &hello_header, - .iov_len = hello_header.len, - }, - .{ - .iov_base = host.ptr, - .iov_len = host.len, - }, - }; - try stream.writevAll(&iovecs); + { + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = &hello_header, + .iov_len = hello_header.len, + }, + .{ + .iov_base = host.ptr, + .iov_len = host.len, + }, + }; + try stream.writevAll(&iovecs); + } const client_hello_bytes1 = hello_header[5..]; - var client_handshake_key: [32]u8 = undefined; - var server_handshake_key: [32]u8 = undefined; - var client_handshake_iv: [12]u8 = undefined; - var server_handshake_iv: [12]u8 = undefined; - var cipher_suite: CipherSuite = undefined; + var cipher_params: CipherParams = undefined; var handshake_buf: [4000]u8 = undefined; var len: usize = 0; @@ -386,16 +452,16 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + const cipher_suite_tag = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch return error.TlsIllegalParameter; - std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)}); const legacy_compression_method = hello[37]; _ = legacy_compression_method; const extensions_size = mem.readIntBig(u16, hello[38..40]); if (40 + extensions_size != hello.len) return error.TlsBadLength; var i: usize = 40; var supported_version: u16 = 0; - var have_server_pub_key = false; + var opt_x25519_server_pub_key: ?*[32]u8 = null; while (i < hello.len) { const et = mem.readIntBig(u16, hello[i..][0..2]); i += 2; @@ -409,7 +475,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { supported_version = mem.readIntBig(u16, hello[i..][0..2]); }, @enumToInt(ExtensionType.key_share) => { - if (have_server_pub_key) return error.TlsIllegalParameter; + if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; const named_group = mem.readIntBig(u16, hello[i..][0..2]); i += 2; switch (named_group) { @@ -417,8 +483,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const key_size = mem.readIntBig(u16, hello[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; - tls.x25519_server_pub_key = hello[i..][0..32].*; - have_server_pub_key = true; + opt_x25519_server_pub_key = hello[i..][0..32]; }, else => { std.debug.print("named group: {x}\n", .{named_group}); @@ -432,7 +497,8 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } i = next_i; } - if (!have_server_pub_key) return error.TlsIllegalParameter; + const x25519_server_pub_key = opt_x25519_server_pub_key orelse + return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { @enumToInt(ProtocolVersion.tls_1_2) => { @@ -445,28 +511,44 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } const shared_key = crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - tls.x25519_server_pub_key, + x25519_priv_key, + x25519_server_pub_key.*, ) catch return error.TlsDecryptFailure; - switch (cipher_suite) { - .TLS_AES_128_GCM_SHA256 => { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const Hash = crypto.hash.sha2.Sha256; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); - const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); - const empty_hash = emptyHash(Hash); - const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); - const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); - server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); - client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); - server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + switch (cipher_suite_tag) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| { + const P = std.meta.TagPayload(CipherParams, tag); + cipher_params = @unionInit(CipherParams, @tagName(tag), .{ + .handshake_secret = undefined, + .master_secret = undefined, + .client_handshake_key = undefined, + .server_handshake_key = undefined, + .client_finished_key = undefined, + .server_finished_key = undefined, + .client_handshake_iv = undefined, + .server_handshake_iv = undefined, + .transcript_hash = P.Hash.init(.{}), + }); + const p = &@field(cipher_params, @tagName(tag)); + p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(frag); // Server Hello + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = emptyHash(P.Hash); + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ // std.fmt.fmtSliceHexLower(&shared_key), // std.fmt.fmtSliceHexLower(&hello_hash), @@ -478,24 +560,6 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { // std.fmt.fmtSliceHexLower(&server_secret), //}); }, - .TLS_AES_256_GCM_SHA384 => { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const Hash = crypto.hash.sha2.Sha384; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); - const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); - const empty_hash = emptyHash(Hash); - const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); - const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); - server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); - client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); - server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); - }, .TLS_CHACHA20_POLY1305_SHA256 => { @panic("TODO"); }, @@ -541,50 +605,24 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }, .application_data => { var cleartext_buf: [1000]u8 = undefined; - const cleartext = switch (cipher_suite) { - .TLS_AES_128_GCM_SHA256 => c: { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const ciphertext_len = record_size - AEAD.tag_length; + const cleartext = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + const P = @TypeOf(p.*); + const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext = handshake_buf[i..][0..ciphertext_len]; i += ciphertext.len; if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; - const V = @Vector(AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*; + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); read_seq += 1; - const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; - //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{ - // read_seq - 1, - // std.fmt.fmtSliceHexLower(&nonce), - // std.fmt.fmtSliceHexLower(&@as([12]u8, operand)), - //}); + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_handshake_iv) ^ operand; const ad = handshake_buf[end_hdr - 5 ..][0..5]; - const key = server_handshake_key[0..AEAD.key_length].*; - AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; - - break :c cleartext; - }, - .TLS_AES_256_GCM_SHA384 => c: { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const ciphertext_len = record_size - AEAD.tag_length; - const ciphertext = handshake_buf[i..][0..ciphertext_len]; - i += ciphertext.len; - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; - const V = @Vector(AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (AEAD.nonce_length - 8); - const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); - read_seq += 1; - const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; - const ad = handshake_buf[end_hdr - 5 ..][0..5]; - const key = server_handshake_key[0..AEAD.key_length].*; - AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch - return error.TlsBadRecordMac; - + p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]); break :c cleartext; }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ -611,6 +649,86 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } std.debug.print("empty encrypted extensions\n", .{}); }, + @enumToInt(HandshakeType.certificate) => { + std.debug.print("cool certificate bro\n", .{}); + }, + @enumToInt(HandshakeType.certificate_verify) => { + std.debug.print("the certificate came with a fancy signature\n", .{}); + }, + @enumToInt(HandshakeType.finished) => { + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { + const P = @TypeOf(p.*); + // TODO verify the server's data + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(HandshakeType.finished), + 0, 0, verify_data.len + 1 + P.AEAD.tag_length, // length + } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; + + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + var finished_msg = [_]u8{ + @enumToInt(ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ ([1]u8{undefined} ** wrapped_len); + + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + { + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = &client_change_cipher_spec_msg, + .iov_len = client_change_cipher_spec_msg.len, + }, + .{ + .iov_base = &finished_msg, + .iov_len = finished_msg.len, + }, + }; + try stream.writevAll(&iovecs); + } + + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + return .{ + .application_cipher = app_cipher, + .read_seq = read_seq, + .write_seq = 1, + .partially_read_buffer = undefined, + .partially_read_len = 0, + }; + }, else => { std.debug.print("handshake type: {d}\n", .{cleartext[0]}); return error.TlsUnexpectedMessage; @@ -631,14 +749,185 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { i = end; } - tls.state = .sent_hello; + return error.TlsHandshakeFailure; } -pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void { - _ = tls; - _ = stream; - _ = buffer; - @panic("hold on a minute, we didn't finish implementing the handshake yet"); +pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { + var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined; + var iovecs_buf: [5]std.os.iovec_const = undefined; + var ciphertext_end: usize = 0; + var iovec_end: usize = 0; + var bytes_i: usize = 0; + switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + while (true) { + const ciphertext_len = @intCast(u16, @min( + @min(bytes.len - bytes_i, max_ciphertext_len), + ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end, + )); + if (ciphertext_len == 0) return bytes_i; + + const wrapped_len = ciphertext_len + P.AEAD.tag_length; + const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len]; + + const ad = record[0..5]; + ciphertext_end += 5; + const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; + ciphertext_end += ciphertext_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; + ciphertext_end += P.AEAD.tag_length; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq)); + tls.write_seq += 1; + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; + ad.* = + [_]u8{@enumToInt(ContentType.application_data)} ++ + int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ + int2(wrapped_len); + const cleartext = bytes[bytes_i..ciphertext.len]; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + + iovecs_buf[iovec_end] = .{ + .iov_base = record.ptr, + .iov_len = record.len, + }; + iovec_end += 1; + + bytes_i += ciphertext_len; + } + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + } + + // Ideally we would call writev exactly once here, however, we must ensure + // that we don't return with a record partially written. + var i: usize = 0; + var total_amt: usize = 0; + while (true) { + var amt = try stream.writev(iovecs_buf[i..iovec_end]); + total_amt += amt; + while (amt >= iovecs_buf[i].iov_len) { + amt -= iovecs_buf[i].iov_len; + i += 1; + // Rely on the property that iovecs delineate records, meaning that + // if amt equals zero here, we have fortunately found ourselves + // with a short read that aligns at the record boundary. + if (i >= iovec_end or amt == 0) return total_amt; + } + iovecs_buf[i].iov_base += amt; + iovecs_buf[i].iov_len -= amt; + } +} + +pub fn writeAll(tls: *Tls, stream: net.Stream, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try tls.write(stream, bytes[index..]); + } +} + +/// Returns number of bytes that have been read, which are now populated inside +/// `buffer`. A return value of zero bytes does not necessarily mean end of +/// stream. +pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { + const prev_len = tls.partially_read_len; + var in_buf: [max_ciphertext_len * 4]u8 = undefined; + mem.copy(u8, &in_buf, tls.partially_read_buffer[0..prev_len]); + + // Capacity of output buffer, in records, rounded up. + const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); + const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]); + const frag = in_buf[0 .. prev_len + actual_read_len]; + var in: usize = 0; + var out: usize = 0; + + while (true) { + if (in + ciphertext_record_header_len > frag.len) { + return finishRead(tls, frag, in, out); + } + const ct = @intToEnum(ContentType, frag[in]); + in += 1; + const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); + in += 2; + _ = legacy_version; + const record_size = mem.readIntBig(u16, frag[in..][0..2]); + in += 2; + const end = in + record_size; + if (end > frag.len) { + if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; + return finishRead(tls, frag, in, out); + } + switch (ct) { + .alert => { + @panic("TODO handle an alert here"); + }, + .application_data => { + const cleartext_len = switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + const ciphertext_len = record_size - P.AEAD.tag_length; + const ciphertext = frag[in..][0..ciphertext_len]; + in += ciphertext_len; + const auth_tag = frag[in..][0..P.AEAD.tag_length].*; + const cleartext = buffer[out..][0..ciphertext_len]; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq)); + tls.read_seq += 1; + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; + const ad = frag[0..ciphertext_record_header_len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch + return error.TlsBadRecordMac; + break :c cleartext.len; + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + + const inner_ct = buffer[out + cleartext_len - 1]; + switch (inner_ct) { + @enumToInt(ContentType.handshake) => { + std.debug.print("the server wants to keep shaking hands\n", .{}); + }, + @enumToInt(ContentType.application_data) => { + out += cleartext_len - 1; + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + in = end; + } +} + +fn finishRead(tls: *Tls, frag: []const u8, in: usize, out: usize) usize { + const saved_buf = frag[in..]; + mem.copy(u8, &tls.partially_read_buffer, saved_buf); + tls.partially_read_len = @intCast(u15, saved_buf.len); + return out; } fn hkdfExpandLabel( @@ -674,13 +963,9 @@ fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { return result; } -fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 { - var h = Hash.init(.{}); - h.update(s0); - h.update(s1); - h.update(s2); - var result: [Hash.digest_length]u8 = undefined; - h.final(&result); +fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 { + var result: [Hmac.mac_length]u8 = undefined; + Hmac.create(&result, message, &key); return result; } @@ -693,3 +978,10 @@ inline fn big(x: anytype) @TypeOf(x) { .Little => @byteSwap(x), }; } + +inline fn int2(x: u16) [2]u8 { + return .{ + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 9cdf8edcf1..217dea3723 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -142,6 +142,11 @@ fn Sha2x32(comptime params: Sha2Params32) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -175,6 +180,12 @@ fn Sha2x32(comptime params: Sha2Params32) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + const W = [64]u32{ 0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5, 0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174, @@ -621,6 +632,11 @@ fn Sha2x64(comptime params: Sha2Params64) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -654,6 +670,12 @@ fn Sha2x64(comptime params: Sha2Params64) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + fn round(d: *Self, b: *const [128]u8) void { var s: [80]u64 = undefined; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index b10011a6b1..e7b056830a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -12,7 +12,7 @@ pub const Request = struct { client: *Client, stream: net.Stream, headers: std.ArrayListUnmanaged(u8) = .{}, - tls: std.crypto.Tls = .{}, + tls: std.crypto.Tls, protocol: Protocol, pub const Protocol = enum { http, https }; @@ -55,6 +55,13 @@ pub const Request = struct { }, } } + + pub fn read(req: *Request, buffer: []u8) !usize { + switch (req.protocol) { + .http => return req.stream.read(buffer), + .https => return req.tls.read(req.stream, buffer), + } + } }; pub fn deinit(client: *Client) void { @@ -68,6 +75,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { .client = client, .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), .protocol = options.protocol, + .tls = undefined, }; client.active_requests += 1; errdefer req.deinit(); @@ -75,7 +83,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { switch (options.protocol) { .http => {}, .https => { - try req.tls.init(req.stream, options.host); + req.tls = try std.crypto.Tls.init(req.stream, options.host); }, }