std.crypto.tls: use a Decoder abstraction

This commit introduces tls.Decoder and then uses it in tls.Client. The
purpose is to make it difficult to introduce vulnerabilities in the
parsing code. With this abstraction in place, bugs in the TLS
implementation will trip checks in the decoder, regardless of the actual
length of packets sent by the other party, so that we can have
confidence when using ReleaseFast builds.
This commit is contained in:
Andrew Kelley 2022-12-30 17:57:31 -07:00
parent 341e68ff8f
commit 0fb78b15aa
2 changed files with 423 additions and 336 deletions

View File

@ -39,9 +39,9 @@ const assert = std.debug.assert;
pub const Client = @import("tls/Client.zig"); pub const Client = @import("tls/Client.zig");
pub const ciphertext_record_header_len = 5; pub const record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256; pub const max_ciphertext_len = (1 << 14) + 256;
pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len;
pub const hello_retry_request_sequence = [32]u8{ pub const hello_retry_request_sequence = [32]u8{
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
@ -360,3 +360,130 @@ pub inline fn int3(x: u24) [3]u8 {
@truncate(u8, x), @truncate(u8, x),
}; };
} }
/// An abstraction to ensure that protocol-parsing code does not perform an
/// out-of-bounds read.
pub const Decoder = struct {
buf: []u8,
/// Points to the next byte in buffer that will be decoded.
idx: usize = 0,
/// Up to this point in `buf` we have already checked that `cap` is greater than it.
our_end: usize = 0,
/// Beyond this point in `buf` is extra tag-along bytes beyond the amount we
/// requested with `readAtLeast`.
their_end: usize = 0,
/// Points to the end within buffer that has been filled. Beyond this point
/// in buf is undefined bytes.
cap: usize = 0,
/// Debug helper to prevent illegal calls to read functions.
disable_reads: bool = false,
pub fn fromTheirSlice(buf: []u8) Decoder {
return .{
.buf = buf,
.their_end = buf.len,
.cap = buf.len,
.disable_reads = true,
};
}
/// Use this function to increase `their_end`.
pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
if (their_amt <= existing_amt) return;
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
const actual_amt = try stream.readAtLeast(dest, request_amt);
if (actual_amt < request_amt) return error.TlsConnectionTruncated;
d.cap += actual_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
}
/// Use this function to increase `our_end`.
/// This should always be called with an amount provided by us, not them.
pub fn ensure(d: *Decoder, amt: usize) !void {
d.our_end = @max(d.idx + amt, d.our_end);
if (d.our_end > d.their_end) return error.TlsDecodeError;
}
/// Use this function to increase `idx`.
pub fn decode(d: *Decoder, comptime T: type) T {
switch (@typeInfo(T)) {
.Int => |info| switch (info.bits) {
8 => {
skip(d, 1);
return d.buf[d.idx - 1];
},
16 => {
skip(d, 2);
const b0: u16 = d.buf[d.idx - 2];
const b1: u16 = d.buf[d.idx - 1];
return (b0 << 8) | b1;
},
24 => {
skip(d, 3);
const b0: u24 = d.buf[d.idx - 3];
const b1: u24 = d.buf[d.idx - 2];
const b2: u24 = d.buf[d.idx - 1];
return (b0 << 16) | (b1 << 8) | b2;
},
else => @compileError("unsupported int type: " ++ @typeName(T)),
},
.Enum => |info| {
const int = d.decode(info.tag_type);
if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
return @intToEnum(T, int);
},
else => @compileError("unsupported type: " ++ @typeName(T)),
}
}
/// Use this function to increase `idx`.
pub fn array(d: *Decoder, comptime len: usize) *[len]u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn slice(d: *Decoder, len: usize) []u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn skip(d: *Decoder, amt: usize) void {
d.idx += amt;
assert(d.idx <= d.our_end); // insufficient ensured bytes
}
pub fn eof(d: Decoder) bool {
assert(d.our_end <= d.their_end);
assert(d.idx <= d.our_end);
return d.idx == d.their_end;
}
/// Provide the length they claim, and receive a sub-decoder specific to that slice.
/// The parent decoder is advanced to the end.
pub fn sub(d: *Decoder, their_len: usize) !Decoder {
const end = d.idx + their_len;
if (end > d.their_end) return error.TlsDecodeError;
const sub_buf = d.buf[d.idx..end];
d.idx = end;
d.our_end = end;
return fromTheirSlice(sub_buf);
}
pub fn rest(d: Decoder) []u8 {
return d.buf[d.idx..d.cap];
}
};

View File

@ -126,88 +126,73 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const client_hello_bytes1 = plaintext_header[5..]; const client_hello_bytes1 = plaintext_header[5..];
var handshake_cipher: tls.HandshakeCipher = undefined; var handshake_cipher: tls.HandshakeCipher = undefined;
var handshake_buffer: [8000]u8 = undefined;
var handshake_buf: [8000]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer };
var len: usize = 0; {
var i: usize = i: { try d.readAtLeastOurAmt(stream, tls.record_header_len);
const plaintext = handshake_buf[0..5]; const ct = d.decode(tls.ContentType);
len = try stream.readAtLeast(&handshake_buf, plaintext.len); d.skip(2); // legacy_record_version
if (len < plaintext.len) return error.EndOfStream; const record_len = d.decode(u16);
const ct = @intToEnum(tls.ContentType, plaintext[0]); try d.readAtLeast(stream, record_len);
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); const server_hello_fragment = d.buf[d.idx..][0..record_len];
const end = plaintext.len + frag_len; var ptd = try d.sub(record_len);
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
len += try stream.readAtLeast(handshake_buf[len..], end - len);
if (end > len) return error.EndOfStream;
}
const frag = handshake_buf[plaintext.len..end];
switch (ct) { switch (ct) {
.alert => { .alert => {
const level = @intToEnum(tls.AlertLevel, frag[0]); try ptd.ensure(2);
const desc = @intToEnum(tls.AlertDescription, frag[1]); const level = ptd.decode(tls.AlertLevel);
const desc = ptd.decode(tls.AlertDescription);
_ = level; _ = level;
_ = desc; _ = desc;
return error.TlsAlert; return error.TlsAlert;
}, },
.handshake => { .handshake => {
if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) { try ptd.ensure(4);
const handshake_type = ptd.decode(tls.HandshakeType);
if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
const length = ptd.decode(u24);
var hsd = try ptd.sub(length);
try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
const legacy_version = hsd.decode(u16);
const random = hsd.array(32);
if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
// This is a HelloRetryRequest message. This client implementation
// does not expect to get one.
return error.TlsUnexpectedMessage; return error.TlsUnexpectedMessage;
} }
const length = mem.readIntBig(u24, frag[1..4]); const legacy_session_id_echo_len = hsd.decode(u8);
if (4 + length != frag.len) return error.TlsBadLength;
var i: usize = 4;
const legacy_version = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const random = frag[i..][0..32].*;
i += 32;
if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) {
@panic("TODO handle HelloRetryRequest");
}
const legacy_session_id_echo_len = frag[i];
i += 1;
if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
const legacy_session_id_echo = frag[i..][0..32]; const legacy_session_id_echo = hsd.array(32);
if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
return error.TlsIllegalParameter; return error.TlsIllegalParameter;
i += 32; const cipher_suite_tag = hsd.decode(tls.CipherSuite);
const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); hsd.skip(1); // legacy_compression_method
i += 2; const extensions_size = hsd.decode(u16);
const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int); var all_extd = try hsd.sub(extensions_size);
const legacy_compression_method = frag[i];
i += 1;
_ = legacy_compression_method;
const extensions_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
if (i + extensions_size != frag.len) return error.TlsBadLength;
var supported_version: u16 = 0; var supported_version: u16 = 0;
var shared_key: [32]u8 = undefined; var shared_key: [32]u8 = undefined;
var have_shared_key = false; var have_shared_key = false;
while (i < frag.len) { while (!all_extd.eof()) {
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2])); try all_extd.ensure(2 + 2);
i += 2; const et = all_extd.decode(tls.ExtensionType);
const ext_size = mem.readIntBig(u16, frag[i..][0..2]); const ext_size = all_extd.decode(u16);
i += 2; var extd = try all_extd.sub(ext_size);
const next_i = i + ext_size;
if (next_i > frag.len) return error.TlsBadLength;
switch (et) { switch (et) {
.supported_versions => { .supported_versions => {
if (supported_version != 0) return error.TlsIllegalParameter; if (supported_version != 0) return error.TlsIllegalParameter;
supported_version = mem.readIntBig(u16, frag[i..][0..2]); try extd.ensure(2);
supported_version = extd.decode(u16);
}, },
.key_share => { .key_share => {
if (have_shared_key) return error.TlsIllegalParameter; if (have_shared_key) return error.TlsIllegalParameter;
have_shared_key = true; have_shared_key = true;
const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2])); try extd.ensure(4);
i += 2; const named_group = extd.decode(tls.NamedGroup);
const key_size = mem.readIntBig(u16, frag[i..][0..2]); const key_size = extd.decode(u16);
i += 2; try extd.ensure(key_size);
switch (named_group) { switch (named_group) {
.x25519 => { .x25519 => {
if (key_size != 32) return error.TlsBadLength; if (key_size != 32) return error.TlsIllegalParameter;
const server_pub_key = frag[i..][0..32]; const server_pub_key = extd.array(32);
shared_key = crypto.dh.X25519.scalarmult( shared_key = crypto.dh.X25519.scalarmult(
x25519_kp.secret_key, x25519_kp.secret_key,
@ -215,7 +200,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
) catch return error.TlsDecryptFailure; ) catch return error.TlsDecryptFailure;
}, },
.secp256r1 => { .secp256r1 => {
const server_pub_key = frag[i..][0..key_size]; const server_pub_key = extd.slice(key_size);
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
const pk = PublicKey.fromSec1(server_pub_key) catch { const pk = PublicKey.fromSec1(server_pub_key) catch {
@ -233,14 +218,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
}, },
else => {}, else => {},
} }
i = next_i;
} }
if (!have_shared_key) return error.TlsIllegalParameter; if (!have_shared_key) return error.TlsIllegalParameter;
const tls_version = if (supported_version == 0) legacy_version else supported_version; const tls_version = if (supported_version == 0) legacy_version else supported_version;
switch (tls_version) { if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3))
@enumToInt(tls.ProtocolVersion.tls_1_3) => {}, return error.TlsIllegalParameter;
else => return error.TlsIllegalParameter,
}
switch (cipher_suite_tag) { switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256, inline .AES_128_GCM_SHA256,
@ -264,7 +247,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const p = &@field(handshake_cipher, @tagName(tag)); const p = &@field(handshake_cipher, @tagName(tag));
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
p.transcript_hash.update(host); // Client Hello part 2 p.transcript_hash.update(host); // Client Hello part 2
p.transcript_hash.update(frag); // Server Hello p.transcript_hash.update(server_hello_fragment);
const hello_hash = p.transcript_hash.peek(); const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length; const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
@ -289,8 +272,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
}, },
else => return error.TlsUnexpectedMessage, else => return error.TlsUnexpectedMessage,
} }
break :i end; }
};
// This is used for two purposes: // This is used for two purposes:
// * Detect whether a certificate is the first one presented, in which case // * Detect whether a certificate is the first one presented, in which case
@ -322,29 +304,17 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
var main_cert_pub_key_len: u16 = undefined; var main_cert_pub_key_len: u16 = undefined;
while (true) { while (true) {
const end_hdr = i + 5; try d.readAtLeastOurAmt(stream, tls.record_header_len);
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; const record_header = d.buf[d.idx..][0..5];
if (end_hdr > len) { const ct = d.decode(tls.ContentType);
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); d.skip(2); // legacy_version
if (end_hdr > len) return error.EndOfStream; const record_len = d.decode(u16);
} try d.readAtLeast(stream, record_len);
const ct = @intToEnum(tls.ContentType, handshake_buf[i]); var record_decoder = try d.sub(record_len);
i += 1;
const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
_ = legacy_version;
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
const end = i + record_size;
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
len += try stream.readAtLeast(handshake_buf[len..], end - len);
if (end > len) return error.EndOfStream;
}
switch (ct) { switch (ct) {
.change_cipher_spec => { .change_cipher_spec => {
if (record_size != 1) return error.TlsUnexpectedMessage; try record_decoder.ensure(1);
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
}, },
.application_data => { .application_data => {
const cleartext_buf = &cleartext_bufs[cert_index % 2]; const cleartext_buf = &cleartext_bufs[cert_index % 2];
@ -352,276 +322,261 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const cleartext = switch (handshake_cipher) { const cleartext = switch (handshake_cipher) {
inline else => |*p| c: { inline else => |*p| c: {
const P = @TypeOf(p.*); const P = @TypeOf(p.*);
const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext_len = record_len - P.AEAD.tag_length;
const ciphertext = handshake_buf[i..][0..ciphertext_len]; try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
i += ciphertext.len; const ciphertext = record_decoder.slice(ciphertext_len);
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len]; const cleartext = cleartext_buf[0..ciphertext.len];
const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*; const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
const V = @Vector(P.AEAD.nonce_length, u8); const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1; read_seq += 1;
const nonce = @as(V, p.server_handshake_iv) ^ operand; const nonce = @as(V, p.server_handshake_iv) ^ operand;
const ad = handshake_buf[end_hdr - 5 ..][0..5]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
return error.TlsBadRecordMac; return error.TlsBadRecordMac;
break :c cleartext; break :c cleartext;
}, },
}; };
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) { if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
.handshake => {
var ct_i: usize = 0; var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
while (true) { while (true) {
const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); try ctd.ensure(4);
ct_i += 1; const handshake_type = ctd.decode(tls.HandshakeType);
const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); const handshake_len = ctd.decode(u24);
ct_i += 3; var hsd = try ctd.sub(handshake_len);
const next_handshake_i = ct_i + handshake_len; const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
if (next_handshake_i > cleartext.len - 1) const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
return error.TlsBadLength; switch (handshake_type) {
const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; .encrypted_extensions => {
const handshake = cleartext[ct_i..next_handshake_i]; if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
switch (handshake_type) { handshake_state = .certificate;
.encrypted_extensions => { switch (handshake_cipher) {
if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; inline else => |*p| p.transcript_hash.update(wrapped_handshake),
handshake_state = .certificate; }
switch (handshake_cipher) { try hsd.ensure(2);
inline else => |*p| p.transcript_hash.update(wrapped_handshake), const total_ext_size = hsd.decode(u16);
} var all_extd = try hsd.sub(total_ext_size);
const total_ext_size = mem.readIntBig(u16, handshake[0..2]); while (!all_extd.eof()) {
var hs_i: usize = 2; try all_extd.ensure(4);
const end_ext_i = 2 + total_ext_size; const et = all_extd.decode(tls.ExtensionType);
while (hs_i < end_ext_i) { const ext_size = all_extd.decode(u16);
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2])); var extd = try all_extd.sub(ext_size);
hs_i += 2; _ = extd;
const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); switch (et) {
hs_i += 2; .server_name => {},
const next_ext_i = hs_i + ext_size; else => {},
switch (et) { }
.server_name => {}, }
else => {}, },
} .certificate => cert: {
hs_i = next_ext_i; switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
switch (handshake_state) {
.certificate => {},
.trust_chain_established => break :cert,
else => return error.TlsUnexpectedMessage,
}
try hsd.ensure(1 + 4);
const cert_req_ctx_len = hsd.decode(u8);
if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
const certs_size = hsd.decode(u24);
var certs_decoder = try hsd.sub(certs_size);
while (!certs_decoder.eof()) {
try certs_decoder.ensure(3);
const cert_size = certs_decoder.decode(u24);
var certd = try certs_decoder.sub(cert_size);
const subject_cert: Certificate = .{
.buffer = certd.buf,
.index = @intCast(u32, certd.idx),
};
const subject = try subject_cert.parse();
if (cert_index == 0) {
// Verify the host on the first certificate.
if (!hostMatchesCommonName(host, subject.commonName())) {
return error.TlsCertificateHostMismatch;
} }
// Keep track of the public key for the
// certificate_verify message later.
main_cert_pub_key_algo = subject.pub_key_algo;
const pub_key = subject.pubKey();
if (pub_key.len > main_cert_pub_key_buf.len)
return error.CertificatePublicKeyInvalid;
@memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
} else {
try prev_cert.verify(subject);
}
if (ca_bundle.verify(subject)) |_| {
handshake_state = .trust_chain_established;
break :cert;
} else |err| switch (err) {
error.CertificateIssuerNotFound => {},
else => |e| return e,
}
prev_cert = subject;
cert_index += 1;
try certs_decoder.ensure(2);
const total_ext_size = certs_decoder.decode(u16);
var all_extd = try certs_decoder.sub(total_ext_size);
_ = all_extd;
}
},
.certificate_verify => {
switch (handshake_state) {
.trust_chain_established => handshake_state = .finished,
.certificate => return error.TlsCertificateNotVerified,
else => return error.TlsUnexpectedMessage,
}
try hsd.ensure(4);
const scheme = hsd.decode(tls.SignatureScheme);
const sig_len = hsd.decode(u16);
try hsd.ensure(sig_len);
const encoded_sig = hsd.slice(sig_len);
const max_digest_len = 64;
var verify_buffer =
([1]u8{0x20} ** 64) ++
"TLS 1.3, server CertificateVerify\x00".* ++
@as([max_digest_len]u8, undefined);
const verify_bytes = switch (handshake_cipher) {
inline else => |*p| v: {
const transcript_digest = p.transcript_hash.peek();
verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
p.transcript_hash.update(wrapped_handshake);
break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
}, },
.certificate => cert: { };
switch (handshake_cipher) { const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
switch (handshake_state) {
.certificate => {},
.trust_chain_established => break :cert,
else => return error.TlsUnexpectedMessage,
}
var hs_i: u32 = 0;
const cert_req_ctx_len = handshake[hs_i];
hs_i += 1;
if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
hs_i += 3;
const end_certs = hs_i + certs_size;
while (hs_i < end_certs) {
const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
hs_i += 3;
const end_cert = hs_i + cert_size;
const subject_cert: Certificate = .{ switch (scheme) {
.buffer = handshake, inline .ecdsa_secp256r1_sha256,
.index = hs_i, .ecdsa_secp384r1_sha384,
}; => |comptime_scheme| {
const subject = try subject_cert.parse(); if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
if (cert_index == 0) { return error.TlsBadSignatureScheme;
// Verify the host on the first certificate. const Ecdsa = SchemeEcdsa(comptime_scheme);
if (!hostMatchesCommonName(host, subject.commonName())) { const sig = try Ecdsa.Signature.fromDer(encoded_sig);
return error.TlsCertificateHostMismatch; const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
} try sig.verify(verify_bytes, key);
// Keep track of the public key for
// the certificate_verify message
// later.
main_cert_pub_key_algo = subject.pub_key_algo;
const pub_key = subject.pubKey();
if (pub_key.len > main_cert_pub_key_buf.len)
return error.CertificatePublicKeyInvalid;
@memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
} else {
try prev_cert.verify(subject);
}
if (ca_bundle.verify(subject)) |_| {
handshake_state = .trust_chain_established;
break :cert;
} else |err| switch (err) {
error.CertificateIssuerNotFound => {},
else => |e| return e,
}
prev_cert = subject;
cert_index += 1;
hs_i = end_cert;
const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
hs_i += 2;
hs_i += total_ext_size;
}
}, },
.certificate_verify => { .rsa_pss_rsae_sha256 => {
switch (handshake_state) { if (main_cert_pub_key_algo != .rsaEncryption)
.trust_chain_established => handshake_state = .finished, return error.TlsBadSignatureScheme;
.certificate => return error.TlsCertificateNotVerified,
else => return error.TlsUnexpectedMessage,
}
const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); const Hash = crypto.hash.sha2.Sha256;
const sig_len = mem.readIntBig(u16, handshake[2..4]); const rsa = Certificate.rsa;
if (4 + sig_len > handshake.len) return error.TlsBadLength; const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
const encoded_sig = handshake[4..][0..sig_len]; const exponent = components.exponent;
const max_digest_len = 64; const modulus = components.modulus;
var verify_buffer = var rsa_mem_buf: [512 * 32]u8 = undefined;
([1]u8{0x20} ** 64) ++ var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
"TLS 1.3, server CertificateVerify\x00".* ++ const ally = fba.allocator();
@as([max_digest_len]u8, undefined); switch (modulus.len) {
inline 128, 256, 512 => |modulus_len| {
const verify_bytes = switch (handshake_cipher) { const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
inline else => |*p| v: { const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
const transcript_digest = p.transcript_hash.peek(); try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
p.transcript_hash.update(wrapped_handshake);
break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
},
};
const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
switch (scheme) {
inline .ecdsa_secp256r1_sha256,
.ecdsa_secp384r1_sha384,
=> |comptime_scheme| {
if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
return error.TlsBadSignatureScheme;
const Ecdsa = SchemeEcdsa(comptime_scheme);
const sig = try Ecdsa.Signature.fromDer(encoded_sig);
const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
try sig.verify(verify_bytes, key);
},
.rsa_pss_rsae_sha256 => {
if (main_cert_pub_key_algo != .rsaEncryption)
return error.TlsBadSignatureScheme;
const Hash = crypto.hash.sha2.Sha256;
const rsa = Certificate.rsa;
const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
const exponent = components.exponent;
const modulus = components.modulus;
var rsa_mem_buf: [512 * 32]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
const ally = fba.allocator();
switch (modulus.len) {
inline 128, 256, 512 => |modulus_len| {
const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
},
else => {
return error.TlsBadRsaSignatureBitCount;
},
}
}, },
else => { else => {
return error.TlsBadSignatureScheme; return error.TlsBadRsaSignatureBitCount;
}, },
} }
}, },
.finished => {
if (handshake_state != .finished) return error.TlsUnexpectedMessage;
// This message is to trick buggy proxies into behaving correctly.
const client_change_cipher_spec_msg = [_]u8{
@enumToInt(tls.ContentType.change_cipher_spec),
0x03, 0x03, // legacy protocol version
0x00, 0x01, // length
0x01,
};
const app_cipher = switch (handshake_cipher) {
inline else => |*p, tag| c: {
const P = @TypeOf(p.*);
const finished_digest = p.transcript_hash.peek();
p.transcript_hash.update(wrapped_handshake);
const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
if (!mem.eql(u8, &expected_server_verify_data, handshake))
return error.TlsDecryptError;
const handshake_hash = p.transcript_hash.finalResult();
const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
const out_cleartext = [_]u8{
@enumToInt(tls.HandshakeType.finished),
0, 0, verify_data.len, // length
} ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
var finished_msg = [_]u8{
@enumToInt(tls.ContentType.application_data),
0x03, 0x03, // legacy protocol version
0, wrapped_len, // byte length of encrypted record
} ++ @as([wrapped_len]u8, undefined);
const ad = finished_msg[0..5];
const ciphertext = finished_msg[5..][0..out_cleartext.len];
const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
const nonce = p.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
try stream.writeAll(&both_msgs);
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
.client_secret = client_secret,
.server_secret = server_secret,
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
});
},
};
var client: Client = .{
.read_seq = 0,
.write_seq = 0,
.partial_cleartext_idx = 0,
.partial_ciphertext_idx = 0,
.partial_ciphertext_end = @intCast(u15, len - end),
.received_close_notify = false,
.application_cipher = app_cipher,
.partially_read_buffer = undefined,
};
mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]);
return client;
},
else => { else => {
return error.TlsUnexpectedMessage; return error.TlsBadSignatureScheme;
}, },
} }
ct_i = next_handshake_i; },
if (ct_i >= cleartext.len - 1) break; .finished => {
} if (handshake_state != .finished) return error.TlsUnexpectedMessage;
}, // This message is to trick buggy proxies into behaving correctly.
else => { const client_change_cipher_spec_msg = [_]u8{
return error.TlsUnexpectedMessage; @enumToInt(tls.ContentType.change_cipher_spec),
}, 0x03, 0x03, // legacy protocol version
0x00, 0x01, // length
0x01,
};
const app_cipher = switch (handshake_cipher) {
inline else => |*p, tag| c: {
const P = @TypeOf(p.*);
const finished_digest = p.transcript_hash.peek();
p.transcript_hash.update(wrapped_handshake);
const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
if (!mem.eql(u8, &expected_server_verify_data, handshake))
return error.TlsDecryptError;
const handshake_hash = p.transcript_hash.finalResult();
const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
const out_cleartext = [_]u8{
@enumToInt(tls.HandshakeType.finished),
0, 0, verify_data.len, // length
} ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
var finished_msg = [_]u8{
@enumToInt(tls.ContentType.application_data),
0x03, 0x03, // legacy protocol version
0, wrapped_len, // byte length of encrypted record
} ++ @as([wrapped_len]u8, undefined);
const ad = finished_msg[0..5];
const ciphertext = finished_msg[5..][0..out_cleartext.len];
const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
const nonce = p.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
try stream.writeAll(&both_msgs);
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
.client_secret = client_secret,
.server_secret = server_secret,
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
});
},
};
const leftover = d.rest();
var client: Client = .{
.read_seq = 0,
.write_seq = 0,
.partial_cleartext_idx = 0,
.partial_ciphertext_idx = 0,
.partial_ciphertext_end = @intCast(u15, leftover.len),
.received_close_notify = false,
.application_cipher = app_cipher,
.partially_read_buffer = undefined,
};
mem.copy(u8, &client.partially_read_buffer, leftover);
return client;
},
else => {
return error.TlsUnexpectedMessage;
},
}
if (ctd.eof()) break;
} }
}, },
else => { else => {
return error.TlsUnexpectedMessage; return error.TlsUnexpectedMessage;
}, },
} }
i = end;
} }
return error.TlsHandshakeFailure;
} }
pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
@ -638,12 +593,12 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
inline else => |*p| l: { inline else => |*p| l: {
const P = @TypeOf(p.*); const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8); const V = @Vector(P.AEAD.nonce_length, u8);
const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1; const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
while (true) { while (true) {
const encrypted_content_len = @intCast(u16, @min( const encrypted_content_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1), @min(bytes.len - bytes_i, max_ciphertext_len - 1),
ciphertext_buf.len - ciphertext_buf.len -
tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
)); ));
if (encrypted_content_len == 0) break :l overhead_len; if (encrypted_content_len == 0) break :l overhead_len;
@ -829,7 +784,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// Cleartext capacity of output buffer, in records, rounded up. // Cleartext capacity of output buffer, in records, rounded up.
const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
const actual_read_len = try stream.readv(ask_iovecs); const actual_read_len = try stream.readv(ask_iovecs);
@ -860,13 +815,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
continue; continue;
} }
if (in + tls.ciphertext_record_header_len > frag.len) { if (in + tls.record_header_len > frag.len) {
if (frag.ptr == frag1.ptr) if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total); return finishRead(c, frag, in, vp.total);
const first = frag[in..]; const first = frag[in..];
if (frag1.len < tls.ciphertext_record_header_len) if (frag1.len < tls.record_header_len)
return finishRead2(c, first, frag1, vp.total); return finishRead2(c, first, frag1, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment. // A record straddles the two fragments. Copy into the now-empty first fragment.
@ -875,7 +830,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const record_len = (record_len_byte_0 << 8) | record_len_byte_1; const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
const full_record_len = record_len + tls.ciphertext_record_header_len; const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len; const second_len = full_record_len - first.len;
if (frag1.len < second_len) if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total); return finishRead2(c, first, frag1, vp.total);
@ -898,14 +853,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const end = in + record_len; const end = in + record_len;
if (end > frag.len) { if (end > frag.len) {
// We need the record header on the next iteration of the loop. // We need the record header on the next iteration of the loop.
in -= tls.ciphertext_record_header_len; in -= tls.record_header_len;
if (frag.ptr == frag1.ptr) if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total); return finishRead(c, frag, in, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment. // A record straddles the two fragments. Copy into the now-empty first fragment.
const first = frag[in..]; const first = frag[in..];
const full_record_len = record_len + tls.ciphertext_record_header_len; const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len; const second_len = full_record_len - first.len;
if (frag1.len < second_len) if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total); return finishRead2(c, first, frag1, vp.total);
@ -919,7 +874,12 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
} }
switch (ct) { switch (ct) {
.alert => { .alert => {
@panic("TODO handle an alert here"); if (in + 2 > frag.len) return error.TlsDecodeError;
const level = @intToEnum(tls.AlertLevel, frag[in]);
const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
_ = level;
_ = desc;
return error.TlsAlert;
}, },
.application_data => { .application_data => {
const cleartext = switch (c.application_cipher) { const cleartext = switch (c.application_cipher) {