diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 11846ca526..6d6e0754da 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -46,7 +46,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { // Only possible to happen if the private key is all zeroes. error.IdentityElement => return error.InsufficientEntropy, }; - _ = secp256r1_kp; const extensions_payload = tls.extension(.supported_versions, [_]u8{ @@ -70,11 +69,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .rsa_pkcs1_sha1, .ecdsa_sha1, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ - //.secp256r1, + .secp256r1, .x25519, })) ++ tls.extension( .key_share, - array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ array(1, x25519_kp.public_key)), + array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ + array(1, x25519_kp.public_key) ++ + int2(@enumToInt(tls.NamedGroup.secp256r1)) ++ + array(1, secp256r1_kp.public_key.toUncompressedSec1())), ) ++ int2(@enumToInt(tls.ExtensionType.server_name)) ++ int2(host_len + 5) ++ // byte length of this extension payload @@ -182,7 +184,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { i += 2; if (i + extensions_size != frag.len) return error.TlsBadLength; var supported_version: u16 = 0; - var opt_x25519_server_pub_key: ?*[32]u8 = null; + var shared_key: [32]u8 = undefined; + var have_shared_key = false; while (i < frag.len) { const et = mem.readIntBig(u16, frag[i..][0..2]); i += 2; @@ -196,15 +199,34 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { supported_version = mem.readIntBig(u16, frag[i..][0..2]); }, @enumToInt(tls.ExtensionType.key_share) => { - if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; + if (have_shared_key) return error.TlsIllegalParameter; + have_shared_key = true; const named_group = mem.readIntBig(u16, frag[i..][0..2]); i += 2; + const key_size = mem.readIntBig(u16, frag[i..][0..2]); + i += 2; + switch (named_group) { @enumToInt(tls.NamedGroup.x25519) => { - const key_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; if (key_size != 32) return error.TlsBadLength; - opt_x25519_server_pub_key = frag[i..][0..32]; + const server_pub_key = frag[i..][0..32]; + + shared_key = crypto.dh.X25519.scalarmult( + x25519_kp.secret_key, + server_pub_key.*, + ) catch return error.TlsDecryptFailure; + }, + @enumToInt(tls.NamedGroup.secp256r1) => { + const server_pub_key = frag[i..][0..key_size]; + + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch { + return error.TlsDecryptFailure; + }; + const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch { + return error.TlsDecryptFailure; + }; + shared_key = mul.affineCoordinates().x.toBytes(.Big); }, else => { std.debug.print("named group: {x}\n", .{named_group}); @@ -218,8 +240,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { } i = next_i; } - const x25519_server_pub_key = opt_x25519_server_pub_key orelse - return error.TlsIllegalParameter; + if (!have_shared_key) return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { @enumToInt(tls.ProtocolVersion.tls_1_2) => { @@ -231,11 +252,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { else => return error.TlsIllegalParameter, } - const shared_key = crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - x25519_server_pub_key.*, - ) catch return error.TlsDecryptFailure; - switch (cipher_suite_tag) { inline .AES_128_GCM_SHA256, .AES_256_GCM_SHA384,