diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 7732f3b74e..6479c77d75 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -291,6 +291,12 @@ pub const NamedGroup = enum(u16) { _, }; +pub const PskKeyExchangeMode = enum(u8) { + psk_ke = 0, + psk_dhe_ke = 1, + _, +}; + pub const CipherSuite = enum(u16) { RSA_WITH_AES_128_CBC_SHA = 0x002F, DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033, @@ -407,6 +413,11 @@ pub const CipherSuite = enum(u16) { } }; +pub const CompressionMethod = enum(u8) { + null = 0, + _, +}; + pub const CertificateType = enum(u8) { X509 = 0, RawPublicKey = 2, @@ -419,6 +430,11 @@ pub const KeyUpdateRequest = enum(u8) { _, }; +pub const ChangeCipherSpecType = enum(u8) { + change_cipher_spec = 1, + _, +}; + pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { return struct { pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length); @@ -560,34 +576,38 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } -pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { - return int2(@intFromEnum(et)) ++ array(1, bytes); +pub inline fn extension(et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { + return int(u16, @intFromEnum(et)) ++ array(u16, u8, bytes); } -pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { - comptime assert(bytes.len % elem_size == 0); - return int2(bytes.len) ++ bytes; -} - -pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { - assert(@sizeOf(E) == 2); - var result: [tags.len * 2]u8 = undefined; - for (tags, 0..) |elem, i| { - result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); - result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); +pub inline fn array( + comptime Len: type, + comptime Elem: type, + elems: anytype, +) [@divExact(@bitSizeOf(Len), 8) + @divExact(@bitSizeOf(Elem), 8) * elems.len]u8 { + const len_size = @divExact(@bitSizeOf(Len), 8); + const elem_size = @divExact(@bitSizeOf(Elem), 8); + var arr: [len_size + elem_size * elems.len]u8 = undefined; + std.mem.writeInt(Len, arr[0..len_size], @intCast(elem_size * elems.len), .big); + const ElemInt = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(Elem) } }); + for (0.., @as([elems.len]Elem, elems)) |index, elem| { + std.mem.writeInt( + ElemInt, + arr[len_size + elem_size * index ..][0..elem_size], + switch (@typeInfo(Elem)) { + .int => @as(Elem, elem), + .@"enum" => @intFromEnum(@as(Elem, elem)), + else => @bitCast(@as(Elem, elem)), + }, + .big, + ); } - return array(2, result); -} - -pub inline fn int2(int: u16) [2]u8 { - var arr: [2]u8 = undefined; - std.mem.writeInt(u16, &arr, int, .big); return arr; } -pub inline fn int3(int: u24) [3]u8 { - var arr: [3]u8 = undefined; - std.mem.writeInt(u24, &arr, int, .big); +pub inline fn int(comptime Int: type, val: Int) [@divExact(@bitSizeOf(Int), 8)]u8 { + var arr: [@divExact(@bitSizeOf(Int), 8)]u8 = undefined; + std.mem.writeInt(Int, &arr, val, .big); return arr; } @@ -670,9 +690,8 @@ pub const Decoder = struct { else => @compileError("unsupported int type: " ++ @typeName(T)), }, .@"enum" => |info| { - const int = d.decode(info.tag_type); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); + return @enumFromInt(d.decode(info.tag_type)); }, else => @compileError("unsupported type: " ++ @typeName(T)), } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 5b1d0a9bf6..a8624fd03f 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -10,10 +10,8 @@ const Certificate = std.crypto.Certificate; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; +const int = tls.int; const array = tls.array; -const enum_array = tls.enum_array; tls_version: tls.ProtocolVersion, read_seq: u64, @@ -156,70 +154,62 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In error.IdentityElement => return error.InsufficientEntropy, }; - const extensions_payload = - tls.extension(.supported_versions, [_]u8{2 + 2} ++ // byte length of supported versions - int2(@intFromEnum(tls.ProtocolVersion.tls_1_3)) ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2))) ++ - tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + const extensions_payload = tls.extension(.supported_versions, array(u8, tls.ProtocolVersion, .{ + .tls_1_3, + .tls_1_2, + })) ++ tls.extension(.signature_algorithms, array(u16, tls.SignatureScheme, .{ .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ + })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{ .x25519_ml_kem768, .secp256r1, .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, key_share.x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ - array(1, key_share.x25519_kp.public_key ++ key_share.ml_kem768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); + })) ++ tls.extension(.psk_key_exchange_modes, array(u8, tls.PskKeyExchangeMode, .{ + .psk_dhe_ke, + })) ++ tls.extension(.key_share, array( + u16, + u8, + int(u16, @intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ + array(u16, u8, key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) ++ + int(u16, @intFromEnum(tls.NamedGroup.secp256r1)) ++ + array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ + int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++ + array(u16, u8, key_share.x25519_kp.public_key), + )) ++ int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++ + int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload + int(u16, 1 + 2 + host_len) ++ // server_name_list byte count + .{0x00} ++ // name_type + int(u16, host_len); const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ + int(u16, @intCast(extensions_payload.len + host_len)) ++ extensions_payload; - const legacy_compression_methods = 0x0100; - const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ client_hello_rand ++ [1]u8{32} ++ legacy_session_id ++ cipher_suites ++ - int2(legacy_compression_methods) ++ + array(u8, tls.CompressionMethod, .{.null}) ++ extensions_header; - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ + const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++ + int(u24, @intCast(client_hello.len + host_len)) ++ client_hello; - const cleartext_header = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ // legacy_record_version - int2(@intCast(out_handshake.len + host_len)) ++ + const cleartext_header = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ + int(u16, @intCast(out_handshake.len + host_len)) ++ out_handshake; { var iovecs = [_]std.posix.iovec_const{ - .{ - .base = &cleartext_header, - .len = cleartext_header.len, - }, - .{ - .base = host.ptr, - .len = host.len, - }, + .{ .base = &cleartext_header, .len = cleartext_header.len }, + .{ .base = host.ptr, .len = host.len }, }; try stream.writevAll(&iovecs); } @@ -526,7 +516,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In }, .change_cipher_spec => { try ctd.ensure(1); - if (ctd.decode(u8) != 0x01) return error.TlsIllegalParameter; + if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter; cipher_state = pending_cipher_state; }, .handshake => while (true) { @@ -648,20 +638,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; handshake_state = .finished; - const client_key_exchange_msg = - [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(0x46) ++ // record length - .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ // handshake type - int3(0x42) ++ // params length - .{0x41} ++ // pubkey length - key_share.secp256r1_kp.public_key.toUncompressedSec1(); - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = - [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(1) ++ // record length - .{0x01}; + const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ + array(u24, u8, array(u8, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()))); + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); const pre_master_secret = key_share.getSharedSecret().?; switch (handshake_cipher) { inline else => |*p| { @@ -680,10 +663,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In @sizeOf(P.Tls_1_2), ); const verify_data_len = 12; - const client_verify_cleartext = - [_]u8{@intFromEnum(tls.HandshakeType.finished)} ++ // handshake type - int3(verify_data_len) ++ // verify data length - hmacExpandLabel(P.Hmac, &master_secret, &.{ "client finished", &p.transcript_hash.peek() }, verify_data_len); + const client_verify_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "client finished", &p.transcript_hash.peek() }, + verify_data_len, + )); p.transcript_hash.update(&client_verify_cleartext); p.version = .{ .tls_1_2 = .{ .server_verify_data = hmacExpandLabel( @@ -709,25 +695,23 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq))); break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand; }; - var client_verify_msg = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(P.record_iv_length + client_verify_cleartext.len + P.mac_length) ++ // record length - nonce[P.fixed_iv_length..].* ++ - @as([client_verify_cleartext.len + P.mac_length]u8, undefined); + var client_verify_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, nonce[P.fixed_iv_length..].* ++ + @as([client_verify_cleartext.len + P.mac_length]u8, undefined)); P.AEAD.encrypt( client_verify_msg[client_verify_msg.len - P.mac_length - client_verify_cleartext.len ..][0..client_verify_cleartext.len], client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], &client_verify_cleartext, - std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int2(client_verify_cleartext.len), + std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len), nonce, pv.app_cipher.client_write_key, ); const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &all_msgs, - .len = all_msgs.len, - }}; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; try stream.writevAll(&all_msgs_vec); }, } @@ -755,11 +739,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = - [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(1) ++ // length - .{0x01}; + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); const app_cipher = app_cipher: switch (handshake_cipher) { inline else => |*p, tag| switch (tls_version) { .tls_1_3 => { @@ -771,17 +753,15 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; + const out_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, verify_data) ++ + .{@intFromEnum(tls.ContentType.handshake)}; const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(wrapped_len) ++ // byte length of encrypted record - @as([wrapped_len]u8, undefined); + var finished_msg = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, @as([wrapped_len]u8, undefined)); const ad = finished_msg[0..tls.record_header_len]; const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len]; @@ -790,10 +770,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &all_msgs, - .len = all_msgs.len, - }}; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; try stream.writevAll(&all_msgs_vec); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); @@ -965,10 +944,9 @@ fn prepareCiphertextRecord( const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); + ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, ciphertext_len + P.AEAD.tag_length); ciphertext_end += ad.len; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; ciphertext_end += ciphertext_len; @@ -1023,10 +1001,10 @@ fn prepareCiphertextRecord( const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; - record_header.* = [_]u8{@intFromEnum(inner_content_type)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(P.record_iv_length + message_len + P.mac_length); - const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int2(message_len); + record_header.* = .{@intFromEnum(inner_content_type)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, P.record_iv_length + message_len + P.mac_length); + const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len); const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; ciphertext_end += P.record_iv_length; const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and @@ -1569,25 +1547,17 @@ const KeyShare = struct { ) error{ TlsIllegalParameter, TlsDecryptFailure }!void { switch (named_group) { .x25519_ml_kem768 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; - if (server_pub_key.len != hksl) return error.TlsIllegalParameter; + const hksl = crypto.kem.ml_kem.MLKem768.ciphertext_length; + const xksl = hksl + crypto.dh.X25519.public_length; + if (server_pub_key.len != xksl) return error.TlsIllegalParameter; - const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..xksl].*) catch + const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[0..hksl]) catch return error.TlsDecryptFailure; - const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[xksl..hksl]) catch + const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[hksl..xksl].*) catch return error.TlsDecryptFailure; - @memcpy(ks.sk_buf[0..xsk.len], &xsk); - @memcpy(ks.sk_buf[xsk.len..][0..hsk.len], &hsk); - ks.sk_len = xsk.len + hsk.len; - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (server_pub_key.len != ksl) return error.TlsIllegalParameter; - const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch - return error.TlsDecryptFailure; - @memcpy(ks.sk_buf[0..sk.len], &sk); - ks.sk_len = sk.len; + @memcpy(ks.sk_buf[0..hsk.len], &hsk); + @memcpy(ks.sk_buf[hsk.len..][0..xsk.len], &xsk); + ks.sk_len = hsk.len + xsk.len; }, .secp256r1 => { const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; @@ -1598,6 +1568,14 @@ const KeyShare = struct { @memcpy(ks.sk_buf[0..sk.len], &sk); ks.sk_len = sk.len; }, + .x25519 => { + const ksl = crypto.dh.X25519.public_length; + if (server_pub_key.len != ksl) return error.TlsIllegalParameter; + const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, else => return error.TlsIllegalParameter, } } @@ -1877,7 +1855,7 @@ fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { /// aes128-gcm: 138 MiB/s /// aes256-gcm: 120 MiB/s const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, @@ -1888,7 +1866,7 @@ const cipher_suites = if (crypto.core.aes.has_hardware_support) .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, }) else - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .CHACHA20_POLY1305_SHA256, .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .AEGIS_128L_SHA256,