diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 65f54ffa68..ea1bff9a08 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -234,7 +234,12 @@ const cipher_suites = blk: { pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { assert(tls.state == .start); crypto.random.bytes(&tls.x25519_priv_key); - tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key); + tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| { + switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + } + }; // random (u32) var rand_buf: [32]u8 = undefined; @@ -337,6 +342,14 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }; try stream.writevAll(&iovecs); + const client_hello_bytes1 = hello_header[5..]; + + var client_handshake_key: [32]u8 = undefined; + var server_handshake_key: [32]u8 = undefined; + var client_handshake_iv: [12]u8 = undefined; + var server_handshake_iv: [12]u8 = undefined; + var cipher_suite: CipherSuite = undefined; + var handshake_buf: [4000]u8 = undefined; var len: usize = 0; var i: usize = i: { @@ -373,7 +386,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch return error.TlsIllegalParameter; std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); const legacy_compression_method = hello[37]; @@ -404,12 +417,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const key_size = mem.readIntBig(u16, hello[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; - const encrypted_key = hello[i..][0..32].*; - const server_pub_key = try crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - encrypted_key, - ); - tls.x25519_server_pub_key = server_pub_key; + tls.x25519_server_pub_key = hello[i..][0..32].*; have_server_pub_key = true; }, else => { @@ -435,12 +443,77 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }, else => return error.TlsIllegalParameter, } + + const shared_key = crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + tls.x25519_server_pub_key, + ) catch return error.TlsDecryptFailure; + + switch (cipher_suite) { + .TLS_AES_128_GCM_SHA256 => { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); + const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); + const empty_hash = emptyHash(Hash); + const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); + const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); + server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); + client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); + server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ + // std.fmt.fmtSliceHexLower(&shared_key), + // std.fmt.fmtSliceHexLower(&hello_hash), + // std.fmt.fmtSliceHexLower(&early_secret), + // std.fmt.fmtSliceHexLower(&empty_hash), + // std.fmt.fmtSliceHexLower(&derived_secret), + // std.fmt.fmtSliceHexLower(&handshake_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); + }, + .TLS_AES_256_GCM_SHA384 => { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); + const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); + const empty_hash = emptyHash(Hash); + const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); + const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); + server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); + client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); + server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + } }, else => return error.TlsUnexpectedMessage, } break :i end; }; + var read_seq: u64 = 0; + while (true) { const end_hdr = i + 5; if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; @@ -467,7 +540,88 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size}); + var cleartext_buf: [1000]u8 = undefined; + const cleartext = switch (cipher_suite) { + .TLS_AES_128_GCM_SHA256 => c: { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const ciphertext_len = record_size - AEAD.tag_length; + const ciphertext = handshake_buf[i..][0..ciphertext_len]; + i += ciphertext.len; + if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..ciphertext.len]; + const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; + const V = @Vector(AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); + read_seq += 1; + const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; + //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{ + // read_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&@as([12]u8, operand)), + //}); + const ad = handshake_buf[end_hdr - 5 ..][0..5]; + const key = server_handshake_key[0..AEAD.key_length].*; + AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + return error.TlsBadRecordMac; + + break :c cleartext; + }, + .TLS_AES_256_GCM_SHA384 => c: { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const ciphertext_len = record_size - AEAD.tag_length; + const ciphertext = handshake_buf[i..][0..ciphertext_len]; + i += ciphertext.len; + if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..ciphertext.len]; + const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; + const V = @Vector(AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); + read_seq += 1; + const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; + const ad = handshake_buf[end_hdr - 5 ..][0..5]; + const key = server_handshake_key[0..AEAD.key_length].*; + AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + return error.TlsBadRecordMac; + + break :c cleartext; + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + + const inner_ct = cleartext[cleartext.len - 1]; + switch (inner_ct) { + @enumToInt(ContentType.handshake) => { + const handshake_len = mem.readIntBig(u24, cleartext[1..4]); + if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength; + switch (cleartext[0]) { + @enumToInt(HandshakeType.encrypted_extensions) => { + const ext_size = mem.readIntBig(u16, cleartext[4..6]); + if (ext_size != 0) { + @panic("TODO handle encrypted extensions"); + } + std.debug.print("empty encrypted extensions\n", .{}); + }, + else => { + std.debug.print("handshake type: {d}\n", .{cleartext[0]}); + return error.TlsUnexpectedMessage; + }, + } + }, + else => { + std.debug.print("inner content type: {d}\n", .{inner_ct}); + return error.TlsUnexpectedMessage; + }, + } }, else => { std.debug.print("content type: {s}\n", .{@tagName(ct)}); @@ -486,3 +640,56 @@ pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void { _ = buffer; @panic("hold on a minute, we didn't finish implementing the handshake yet"); } + +fn hkdfExpandLabel( + comptime Hkdf: type, + key: [Hkdf.prk_length]u8, + label: []const u8, + context: []const u8, + comptime len: usize, +) [len]u8 { + const max_label_len = 255; + const max_context_len = 255; + const tls13 = "tls13 "; + var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined; + mem.writeIntBig(u16, buf[0..2], len); + buf[2] = @intCast(u8, tls13.len + label.len); + buf[3..][0..tls13.len].* = tls13.*; + var i: usize = 3 + tls13.len; + mem.copy(u8, buf[i..], label); + i += label.len; + buf[i] = @intCast(u8, context.len); + i += 1; + mem.copy(u8, buf[i..], context); + i += context.len; + + var result: [len]u8 = undefined; + Hkdf.expand(&result, buf[0..i], key); + return result; +} + +fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { + var result: [Hash.digest_length]u8 = undefined; + Hash.hash(&.{}, &result, .{}); + return result; +} + +fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 { + var h = Hash.init(.{}); + h.update(s0); + h.update(s1); + h.update(s2); + var result: [Hash.digest_length]u8 = undefined; + h.final(&result); + return result; +} + +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + +inline fn big(x: anytype) @TypeOf(x) { + return switch (native_endian) { + .Big => x, + .Little => @byteSwap(x), + }; +}