std.crypto.Tls: decrypting handshake messages

This commit is contained in:
Andrew Kelley 2022-12-15 00:55:33 -07:00
parent 920e5bc4ff
commit 595fff7cb6

View File

@ -234,7 +234,12 @@ const cipher_suites = blk: {
pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
assert(tls.state == .start);
crypto.random.bytes(&tls.x25519_priv_key);
tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key);
tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| {
switch (err) {
// Only possible to happen if the private key is all zeroes.
error.IdentityElement => return error.InsufficientEntropy,
}
};
// random (u32)
var rand_buf: [32]u8 = undefined;
@ -337,6 +342,14 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
};
try stream.writevAll(&iovecs);
const client_hello_bytes1 = hello_header[5..];
var client_handshake_key: [32]u8 = undefined;
var server_handshake_key: [32]u8 = undefined;
var client_handshake_iv: [12]u8 = undefined;
var server_handshake_iv: [12]u8 = undefined;
var cipher_suite: CipherSuite = undefined;
var handshake_buf: [4000]u8 = undefined;
var len: usize = 0;
var i: usize = i: {
@ -373,7 +386,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
const legacy_session_id_echo_len = hello[34];
if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
return error.TlsIllegalParameter;
std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
const legacy_compression_method = hello[37];
@ -404,12 +417,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
const key_size = mem.readIntBig(u16, hello[i..][0..2]);
i += 2;
if (key_size != 32) return error.TlsBadLength;
const encrypted_key = hello[i..][0..32].*;
const server_pub_key = try crypto.dh.X25519.scalarmult(
tls.x25519_priv_key,
encrypted_key,
);
tls.x25519_server_pub_key = server_pub_key;
tls.x25519_server_pub_key = hello[i..][0..32].*;
have_server_pub_key = true;
},
else => {
@ -435,12 +443,77 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
},
else => return error.TlsIllegalParameter,
}
const shared_key = crypto.dh.X25519.scalarmult(
tls.x25519_priv_key,
tls.x25519_server_pub_key,
) catch return error.TlsDecryptFailure;
switch (cipher_suite) {
.TLS_AES_128_GCM_SHA256 => {
const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
const Hash = crypto.hash.sha2.Sha256;
const Hmac = crypto.auth.hmac.Hmac(Hash);
const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
const empty_hash = emptyHash(Hash);
const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
//std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{
// std.fmt.fmtSliceHexLower(&shared_key),
// std.fmt.fmtSliceHexLower(&hello_hash),
// std.fmt.fmtSliceHexLower(&early_secret),
// std.fmt.fmtSliceHexLower(&empty_hash),
// std.fmt.fmtSliceHexLower(&derived_secret),
// std.fmt.fmtSliceHexLower(&handshake_secret),
// std.fmt.fmtSliceHexLower(&client_secret),
// std.fmt.fmtSliceHexLower(&server_secret),
//});
},
.TLS_AES_256_GCM_SHA384 => {
const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
const Hash = crypto.hash.sha2.Sha384;
const Hmac = crypto.auth.hmac.Hmac(Hash);
const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
const empty_hash = emptyHash(Hash);
const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
},
.TLS_CHACHA20_POLY1305_SHA256 => {
@panic("TODO");
},
.TLS_AES_128_CCM_SHA256 => {
@panic("TODO");
},
.TLS_AES_128_CCM_8_SHA256 => {
@panic("TODO");
},
}
},
else => return error.TlsUnexpectedMessage,
}
break :i end;
};
var read_seq: u64 = 0;
while (true) {
const end_hdr = i + 5;
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
@ -467,7 +540,88 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
},
.application_data => {
std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
var cleartext_buf: [1000]u8 = undefined;
const cleartext = switch (cipher_suite) {
.TLS_AES_128_GCM_SHA256 => c: {
const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
const ciphertext_len = record_size - AEAD.tag_length;
const ciphertext = handshake_buf[i..][0..ciphertext_len];
i += ciphertext.len;
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
const V = @Vector(AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1;
const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
//std.debug.print("seq: {d} nonce: {} operand: {}\n", .{
// read_seq - 1,
// std.fmt.fmtSliceHexLower(&nonce),
// std.fmt.fmtSliceHexLower(&@as([12]u8, operand)),
//});
const ad = handshake_buf[end_hdr - 5 ..][0..5];
const key = server_handshake_key[0..AEAD.key_length].*;
AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
return error.TlsBadRecordMac;
break :c cleartext;
},
.TLS_AES_256_GCM_SHA384 => c: {
const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
const ciphertext_len = record_size - AEAD.tag_length;
const ciphertext = handshake_buf[i..][0..ciphertext_len];
i += ciphertext.len;
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
const V = @Vector(AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1;
const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
const ad = handshake_buf[end_hdr - 5 ..][0..5];
const key = server_handshake_key[0..AEAD.key_length].*;
AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
return error.TlsBadRecordMac;
break :c cleartext;
},
.TLS_CHACHA20_POLY1305_SHA256 => {
@panic("TODO");
},
.TLS_AES_128_CCM_SHA256 => {
@panic("TODO");
},
.TLS_AES_128_CCM_8_SHA256 => {
@panic("TODO");
},
};
const inner_ct = cleartext[cleartext.len - 1];
switch (inner_ct) {
@enumToInt(ContentType.handshake) => {
const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
switch (cleartext[0]) {
@enumToInt(HandshakeType.encrypted_extensions) => {
const ext_size = mem.readIntBig(u16, cleartext[4..6]);
if (ext_size != 0) {
@panic("TODO handle encrypted extensions");
}
std.debug.print("empty encrypted extensions\n", .{});
},
else => {
std.debug.print("handshake type: {d}\n", .{cleartext[0]});
return error.TlsUnexpectedMessage;
},
}
},
else => {
std.debug.print("inner content type: {d}\n", .{inner_ct});
return error.TlsUnexpectedMessage;
},
}
},
else => {
std.debug.print("content type: {s}\n", .{@tagName(ct)});
@ -486,3 +640,56 @@ pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void {
_ = buffer;
@panic("hold on a minute, we didn't finish implementing the handshake yet");
}
fn hkdfExpandLabel(
comptime Hkdf: type,
key: [Hkdf.prk_length]u8,
label: []const u8,
context: []const u8,
comptime len: usize,
) [len]u8 {
const max_label_len = 255;
const max_context_len = 255;
const tls13 = "tls13 ";
var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
mem.writeIntBig(u16, buf[0..2], len);
buf[2] = @intCast(u8, tls13.len + label.len);
buf[3..][0..tls13.len].* = tls13.*;
var i: usize = 3 + tls13.len;
mem.copy(u8, buf[i..], label);
i += label.len;
buf[i] = @intCast(u8, context.len);
i += 1;
mem.copy(u8, buf[i..], context);
i += context.len;
var result: [len]u8 = undefined;
Hkdf.expand(&result, buf[0..i], key);
return result;
}
fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
var result: [Hash.digest_length]u8 = undefined;
Hash.hash(&.{}, &result, .{});
return result;
}
fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 {
var h = Hash.init(.{});
h.update(s0);
h.update(s1);
h.update(s2);
var result: [Hash.digest_length]u8 = undefined;
h.final(&result);
return result;
}
const builtin = @import("builtin");
const native_endian = builtin.cpu.arch.endian();
inline fn big(x: anytype) @TypeOf(x) {
return switch (native_endian) {
.Big => x,
.Little => @byteSwap(x),
};
}