std.crypto.tls: fix x25519_ml_kem768 key share

This is mostly nfc cleanup as I was bisecting the client hello to find
the problematic part, and the only bug fix ended up being

    key_share.x25519_kp.public_key ++
    key_share.ml_kem768_kp.public_key.toBytes()

to

    key_share.ml_kem768_kp.public_key.toBytes() ++
    key_share.x25519_kp.public_key)

and the same swap in `KeyShare.exchange` as per some random blog that
says "a hybrid keyshare, constructed by concatenating the public KEM key
with the public X25519 key".  I also note that based on the same blog
post, there was a draft version of this method that indeed had these
values swapped, and that used to be supported by this code, but it was
not properly fixed up when this code was updated from the draft spec.

Closes #21747
This commit is contained in:
Jacob Young 2024-11-02 02:45:12 -04:00
parent 7f20c78c95
commit 7afb277725
2 changed files with 135 additions and 138 deletions

View File

@ -291,6 +291,12 @@ pub const NamedGroup = enum(u16) {
_, _,
}; };
pub const PskKeyExchangeMode = enum(u8) {
psk_ke = 0,
psk_dhe_ke = 1,
_,
};
pub const CipherSuite = enum(u16) { pub const CipherSuite = enum(u16) {
RSA_WITH_AES_128_CBC_SHA = 0x002F, RSA_WITH_AES_128_CBC_SHA = 0x002F,
DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033, DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033,
@ -407,6 +413,11 @@ pub const CipherSuite = enum(u16) {
} }
}; };
pub const CompressionMethod = enum(u8) {
null = 0,
_,
};
pub const CertificateType = enum(u8) { pub const CertificateType = enum(u8) {
X509 = 0, X509 = 0,
RawPublicKey = 2, RawPublicKey = 2,
@ -419,6 +430,11 @@ pub const KeyUpdateRequest = enum(u8) {
_, _,
}; };
pub const ChangeCipherSpecType = enum(u8) {
change_cipher_spec = 1,
_,
};
pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type {
return struct { return struct {
pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length); pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length);
@ -560,34 +576,38 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8)
return result; return result;
} }
pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { pub inline fn extension(et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 {
return int2(@intFromEnum(et)) ++ array(1, bytes); return int(u16, @intFromEnum(et)) ++ array(u16, u8, bytes);
} }
pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { pub inline fn array(
comptime assert(bytes.len % elem_size == 0); comptime Len: type,
return int2(bytes.len) ++ bytes; comptime Elem: type,
} elems: anytype,
) [@divExact(@bitSizeOf(Len), 8) + @divExact(@bitSizeOf(Elem), 8) * elems.len]u8 {
pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { const len_size = @divExact(@bitSizeOf(Len), 8);
assert(@sizeOf(E) == 2); const elem_size = @divExact(@bitSizeOf(Elem), 8);
var result: [tags.len * 2]u8 = undefined; var arr: [len_size + elem_size * elems.len]u8 = undefined;
for (tags, 0..) |elem, i| { std.mem.writeInt(Len, arr[0..len_size], @intCast(elem_size * elems.len), .big);
result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); const ElemInt = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(Elem) } });
result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); for (0.., @as([elems.len]Elem, elems)) |index, elem| {
std.mem.writeInt(
ElemInt,
arr[len_size + elem_size * index ..][0..elem_size],
switch (@typeInfo(Elem)) {
.int => @as(Elem, elem),
.@"enum" => @intFromEnum(@as(Elem, elem)),
else => @bitCast(@as(Elem, elem)),
},
.big,
);
} }
return array(2, result);
}
pub inline fn int2(int: u16) [2]u8 {
var arr: [2]u8 = undefined;
std.mem.writeInt(u16, &arr, int, .big);
return arr; return arr;
} }
pub inline fn int3(int: u24) [3]u8 { pub inline fn int(comptime Int: type, val: Int) [@divExact(@bitSizeOf(Int), 8)]u8 {
var arr: [3]u8 = undefined; var arr: [@divExact(@bitSizeOf(Int), 8)]u8 = undefined;
std.mem.writeInt(u24, &arr, int, .big); std.mem.writeInt(Int, &arr, val, .big);
return arr; return arr;
} }
@ -670,9 +690,8 @@ pub const Decoder = struct {
else => @compileError("unsupported int type: " ++ @typeName(T)), else => @compileError("unsupported int type: " ++ @typeName(T)),
}, },
.@"enum" => |info| { .@"enum" => |info| {
const int = d.decode(info.tag_type);
if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
return @as(T, @enumFromInt(int)); return @enumFromInt(d.decode(info.tag_type));
}, },
else => @compileError("unsupported type: " ++ @typeName(T)), else => @compileError("unsupported type: " ++ @typeName(T)),
} }

View File

@ -10,10 +10,8 @@ const Certificate = std.crypto.Certificate;
const max_ciphertext_len = tls.max_ciphertext_len; const max_ciphertext_len = tls.max_ciphertext_len;
const hmacExpandLabel = tls.hmacExpandLabel; const hmacExpandLabel = tls.hmacExpandLabel;
const hkdfExpandLabel = tls.hkdfExpandLabel; const hkdfExpandLabel = tls.hkdfExpandLabel;
const int2 = tls.int2; const int = tls.int;
const int3 = tls.int3;
const array = tls.array; const array = tls.array;
const enum_array = tls.enum_array;
tls_version: tls.ProtocolVersion, tls_version: tls.ProtocolVersion,
read_seq: u64, read_seq: u64,
@ -156,70 +154,62 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
error.IdentityElement => return error.InsufficientEntropy, error.IdentityElement => return error.InsufficientEntropy,
}; };
const extensions_payload = const extensions_payload = tls.extension(.supported_versions, array(u8, tls.ProtocolVersion, .{
tls.extension(.supported_versions, [_]u8{2 + 2} ++ // byte length of supported versions .tls_1_3,
int2(@intFromEnum(tls.ProtocolVersion.tls_1_3)) ++ .tls_1_2,
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2))) ++ })) ++ tls.extension(.signature_algorithms, array(u16, tls.SignatureScheme, .{
tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
.ecdsa_secp256r1_sha256, .ecdsa_secp256r1_sha256,
.ecdsa_secp384r1_sha384, .ecdsa_secp384r1_sha384,
.rsa_pss_rsae_sha256, .rsa_pss_rsae_sha256,
.rsa_pss_rsae_sha384, .rsa_pss_rsae_sha384,
.rsa_pss_rsae_sha512, .rsa_pss_rsae_sha512,
.ed25519, .ed25519,
})) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{
.x25519_ml_kem768, .x25519_ml_kem768,
.secp256r1, .secp256r1,
.x25519, .x25519,
})) ++ tls.extension( })) ++ tls.extension(.psk_key_exchange_modes, array(u8, tls.PskKeyExchangeMode, .{
.key_share, .psk_dhe_ke,
array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ })) ++ tls.extension(.key_share, array(
array(1, key_share.x25519_kp.public_key) ++ u16,
int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ u8,
array(1, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ int(u16, @intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++
int2(@intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ array(u16, u8, key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) ++
array(1, key_share.x25519_kp.public_key ++ key_share.ml_kem768_kp.public_key.toBytes())), int(u16, @intFromEnum(tls.NamedGroup.secp256r1)) ++
) ++ array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++
int2(@intFromEnum(tls.ExtensionType.server_name)) ++ int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++
int2(host_len + 5) ++ // byte length of this extension payload array(u16, u8, key_share.x25519_kp.public_key),
int2(host_len + 3) ++ // server_name_list byte count )) ++ int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++
[1]u8{0x00} ++ // name_type int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload
int2(host_len); int(u16, 1 + 2 + host_len) ++ // server_name_list byte count
.{0x00} ++ // name_type
int(u16, host_len);
const extensions_header = const extensions_header =
int2(@intCast(extensions_payload.len + host_len)) ++ int(u16, @intCast(extensions_payload.len + host_len)) ++
extensions_payload; extensions_payload;
const legacy_compression_methods = 0x0100;
const client_hello = const client_hello =
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
client_hello_rand ++ client_hello_rand ++
[1]u8{32} ++ legacy_session_id ++ [1]u8{32} ++ legacy_session_id ++
cipher_suites ++ cipher_suites ++
int2(legacy_compression_methods) ++ array(u8, tls.CompressionMethod, .{.null}) ++
extensions_header; extensions_header;
const out_handshake = const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++
[_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ int(u24, @intCast(client_hello.len + host_len)) ++
int3(@intCast(client_hello.len + host_len)) ++
client_hello; client_hello;
const cleartext_header = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ const cleartext_header = .{@intFromEnum(tls.ContentType.handshake)} ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ // legacy_record_version int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++
int2(@intCast(out_handshake.len + host_len)) ++ int(u16, @intCast(out_handshake.len + host_len)) ++
out_handshake; out_handshake;
{ {
var iovecs = [_]std.posix.iovec_const{ var iovecs = [_]std.posix.iovec_const{
.{ .{ .base = &cleartext_header, .len = cleartext_header.len },
.base = &cleartext_header, .{ .base = host.ptr, .len = host.len },
.len = cleartext_header.len,
},
.{
.base = host.ptr,
.len = host.len,
},
}; };
try stream.writevAll(&iovecs); try stream.writevAll(&iovecs);
} }
@ -526,7 +516,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
}, },
.change_cipher_spec => { .change_cipher_spec => {
try ctd.ensure(1); try ctd.ensure(1);
if (ctd.decode(u8) != 0x01) return error.TlsIllegalParameter; if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter;
cipher_state = pending_cipher_state; cipher_state = pending_cipher_state;
}, },
.handshake => while (true) { .handshake => while (true) {
@ -648,20 +638,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage;
handshake_state = .finished; handshake_state = .finished;
const client_key_exchange_msg = const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++
[_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version array(u16, u8, .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++
int2(0x46) ++ // record length array(u24, u8, array(u8, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1())));
.{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ // handshake type const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++
int3(0x42) ++ // params length int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
.{0x41} ++ // pubkey length array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec});
key_share.secp256r1_kp.public_key.toUncompressedSec1();
// This message is to trick buggy proxies into behaving correctly.
const client_change_cipher_spec_msg =
[_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ // record content type
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version
int2(1) ++ // record length
.{0x01};
const pre_master_secret = key_share.getSharedSecret().?; const pre_master_secret = key_share.getSharedSecret().?;
switch (handshake_cipher) { switch (handshake_cipher) {
inline else => |*p| { inline else => |*p| {
@ -680,10 +663,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
@sizeOf(P.Tls_1_2), @sizeOf(P.Tls_1_2),
); );
const verify_data_len = 12; const verify_data_len = 12;
const client_verify_cleartext = const client_verify_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++
[_]u8{@intFromEnum(tls.HandshakeType.finished)} ++ // handshake type array(u24, u8, hmacExpandLabel(
int3(verify_data_len) ++ // verify data length P.Hmac,
hmacExpandLabel(P.Hmac, &master_secret, &.{ "client finished", &p.transcript_hash.peek() }, verify_data_len); &master_secret,
&.{ "client finished", &p.transcript_hash.peek() },
verify_data_len,
));
p.transcript_hash.update(&client_verify_cleartext); p.transcript_hash.update(&client_verify_cleartext);
p.version = .{ .tls_1_2 = .{ p.version = .{ .tls_1_2 = .{
.server_verify_data = hmacExpandLabel( .server_verify_data = hmacExpandLabel(
@ -709,25 +695,23 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq))); const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq)));
break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand; break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand;
}; };
var client_verify_msg = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type var client_verify_msg = .{@intFromEnum(tls.ContentType.handshake)} ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(P.record_iv_length + client_verify_cleartext.len + P.mac_length) ++ // record length array(u16, u8, nonce[P.fixed_iv_length..].* ++
nonce[P.fixed_iv_length..].* ++ @as([client_verify_cleartext.len + P.mac_length]u8, undefined));
@as([client_verify_cleartext.len + P.mac_length]u8, undefined);
P.AEAD.encrypt( P.AEAD.encrypt(
client_verify_msg[client_verify_msg.len - P.mac_length - client_verify_msg[client_verify_msg.len - P.mac_length -
client_verify_cleartext.len ..][0..client_verify_cleartext.len], client_verify_cleartext.len ..][0..client_verify_cleartext.len],
client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
&client_verify_cleartext, &client_verify_cleartext,
std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int2(client_verify_cleartext.len), std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
nonce, nonce,
pv.app_cipher.client_write_key, pv.app_cipher.client_write_key,
); );
const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg;
var all_msgs_vec = [_]std.posix.iovec_const{.{ var all_msgs_vec = [_]std.posix.iovec_const{
.base = &all_msgs, .{ .base = &all_msgs, .len = all_msgs.len },
.len = all_msgs.len, };
}};
try stream.writevAll(&all_msgs_vec); try stream.writevAll(&all_msgs_vec);
}, },
} }
@ -755,11 +739,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; if (cipher_state == .cleartext) return error.TlsUnexpectedMessage;
if (handshake_state != .finished) return error.TlsUnexpectedMessage; if (handshake_state != .finished) return error.TlsUnexpectedMessage;
// This message is to trick buggy proxies into behaving correctly. // This message is to trick buggy proxies into behaving correctly.
const client_change_cipher_spec_msg = const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++
[_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec});
int2(1) ++ // length
.{0x01};
const app_cipher = app_cipher: switch (handshake_cipher) { const app_cipher = app_cipher: switch (handshake_cipher) {
inline else => |*p, tag| switch (tls_version) { inline else => |*p, tag| switch (tls_version) {
.tls_1_3 => { .tls_1_3 => {
@ -771,17 +753,15 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError; if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError;
const handshake_hash = p.transcript_hash.finalResult(); const handshake_hash = p.transcript_hash.finalResult();
const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key);
const out_cleartext = [_]u8{ const out_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++
@intFromEnum(tls.HandshakeType.finished), array(u24, u8, verify_data) ++
0, 0, verify_data.len, // length .{@intFromEnum(tls.ContentType.handshake)};
} ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)};
const wrapped_len = out_cleartext.len + P.AEAD.tag_length; const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
var finished_msg = [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ var finished_msg = .{@intFromEnum(tls.ContentType.application_data)} ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(wrapped_len) ++ // byte length of encrypted record array(u16, u8, @as([wrapped_len]u8, undefined));
@as([wrapped_len]u8, undefined);
const ad = finished_msg[0..tls.record_header_len]; const ad = finished_msg[0..tls.record_header_len];
const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len]; const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len];
@ -790,10 +770,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key);
const all_msgs = client_change_cipher_spec_msg ++ finished_msg; const all_msgs = client_change_cipher_spec_msg ++ finished_msg;
var all_msgs_vec = [_]std.posix.iovec_const{.{ var all_msgs_vec = [_]std.posix.iovec_const{
.base = &all_msgs, .{ .base = &all_msgs, .len = all_msgs.len },
.len = all_msgs.len, };
}};
try stream.writevAll(&all_msgs_vec); try stream.writevAll(&all_msgs_vec);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
@ -965,10 +944,9 @@ fn prepareCiphertextRecord(
const record_start = ciphertext_end; const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ad.* = ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++
[_]u8{@intFromEnum(tls.ContentType.application_data)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ int(u16, ciphertext_len + P.AEAD.tag_length);
int2(ciphertext_len + P.AEAD.tag_length);
ciphertext_end += ad.len; ciphertext_end += ad.len;
const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
ciphertext_end += ciphertext_len; ciphertext_end += ciphertext_len;
@ -1023,10 +1001,10 @@ fn prepareCiphertextRecord(
const record_start = ciphertext_end; const record_start = ciphertext_end;
const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ciphertext_end += tls.record_header_len; ciphertext_end += tls.record_header_len;
record_header.* = [_]u8{@intFromEnum(inner_content_type)} ++ record_header.* = .{@intFromEnum(inner_content_type)} ++
int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int2(P.record_iv_length + message_len + P.mac_length); int(u16, P.record_iv_length + message_len + P.mac_length);
const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int2(message_len); const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
ciphertext_end += P.record_iv_length; ciphertext_end += P.record_iv_length;
const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and
@ -1569,25 +1547,17 @@ const KeyShare = struct {
) error{ TlsIllegalParameter, TlsDecryptFailure }!void { ) error{ TlsIllegalParameter, TlsDecryptFailure }!void {
switch (named_group) { switch (named_group) {
.x25519_ml_kem768 => { .x25519_ml_kem768 => {
const xksl = crypto.dh.X25519.public_length; const hksl = crypto.kem.ml_kem.MLKem768.ciphertext_length;
const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; const xksl = hksl + crypto.dh.X25519.public_length;
if (server_pub_key.len != hksl) return error.TlsIllegalParameter; if (server_pub_key.len != xksl) return error.TlsIllegalParameter;
const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..xksl].*) catch const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[0..hksl]) catch
return error.TlsDecryptFailure; return error.TlsDecryptFailure;
const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[xksl..hksl]) catch const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[hksl..xksl].*) catch
return error.TlsDecryptFailure; return error.TlsDecryptFailure;
@memcpy(ks.sk_buf[0..xsk.len], &xsk); @memcpy(ks.sk_buf[0..hsk.len], &hsk);
@memcpy(ks.sk_buf[xsk.len..][0..hsk.len], &hsk); @memcpy(ks.sk_buf[hsk.len..][0..xsk.len], &xsk);
ks.sk_len = xsk.len + hsk.len; ks.sk_len = hsk.len + xsk.len;
},
.x25519 => {
const ksl = crypto.dh.X25519.public_length;
if (server_pub_key.len != ksl) return error.TlsIllegalParameter;
const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch
return error.TlsDecryptFailure;
@memcpy(ks.sk_buf[0..sk.len], &sk);
ks.sk_len = sk.len;
}, },
.secp256r1 => { .secp256r1 => {
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
@ -1598,6 +1568,14 @@ const KeyShare = struct {
@memcpy(ks.sk_buf[0..sk.len], &sk); @memcpy(ks.sk_buf[0..sk.len], &sk);
ks.sk_len = sk.len; ks.sk_len = sk.len;
}, },
.x25519 => {
const ksl = crypto.dh.X25519.public_length;
if (server_pub_key.len != ksl) return error.TlsIllegalParameter;
const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch
return error.TlsDecryptFailure;
@memcpy(ks.sk_buf[0..sk.len], &sk);
ks.sk_len = sk.len;
},
else => return error.TlsIllegalParameter, else => return error.TlsIllegalParameter,
} }
} }
@ -1877,7 +1855,7 @@ fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec {
/// aes128-gcm: 138 MiB/s /// aes128-gcm: 138 MiB/s
/// aes256-gcm: 120 MiB/s /// aes256-gcm: 120 MiB/s
const cipher_suites = if (crypto.core.aes.has_hardware_support) const cipher_suites = if (crypto.core.aes.has_hardware_support)
enum_array(tls.CipherSuite, &.{ array(u16, tls.CipherSuite, .{
.AEGIS_128L_SHA256, .AEGIS_128L_SHA256,
.AEGIS_256_SHA512, .AEGIS_256_SHA512,
.AES_128_GCM_SHA256, .AES_128_GCM_SHA256,
@ -1888,7 +1866,7 @@ const cipher_suites = if (crypto.core.aes.has_hardware_support)
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
}) })
else else
enum_array(tls.CipherSuite, &.{ array(u16, tls.CipherSuite, .{
.CHACHA20_POLY1305_SHA256, .CHACHA20_POLY1305_SHA256,
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
.AEGIS_128L_SHA256, .AEGIS_128L_SHA256,