std.crypto.tls: use a Decoder abstraction

This commit introduces tls.Decoder and then uses it in tls.Client. The
purpose is to make it difficult to introduce vulnerabilities in the
parsing code. With this abstraction in place, bugs in the TLS
implementation will trip checks in the decoder, regardless of the actual
length of packets sent by the other party, so that we can have
confidence when using ReleaseFast builds.
This commit is contained in:
Andrew Kelley 2022-12-30 17:57:31 -07:00
parent 341e68ff8f
commit 0fb78b15aa
2 changed files with 423 additions and 336 deletions

View File

@ -39,9 +39,9 @@ const assert = std.debug.assert;
pub const Client = @import("tls/Client.zig");
pub const ciphertext_record_header_len = 5;
pub const record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256;
pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len;
pub const hello_retry_request_sequence = [32]u8{
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
@ -360,3 +360,130 @@ pub inline fn int3(x: u24) [3]u8 {
@truncate(u8, x),
};
}
/// An abstraction to ensure that protocol-parsing code does not perform an
/// out-of-bounds read.
pub const Decoder = struct {
buf: []u8,
/// Points to the next byte in buffer that will be decoded.
idx: usize = 0,
/// Up to this point in `buf` we have already checked that `cap` is greater than it.
our_end: usize = 0,
/// Beyond this point in `buf` is extra tag-along bytes beyond the amount we
/// requested with `readAtLeast`.
their_end: usize = 0,
/// Points to the end within buffer that has been filled. Beyond this point
/// in buf is undefined bytes.
cap: usize = 0,
/// Debug helper to prevent illegal calls to read functions.
disable_reads: bool = false,
pub fn fromTheirSlice(buf: []u8) Decoder {
return .{
.buf = buf,
.their_end = buf.len,
.cap = buf.len,
.disable_reads = true,
};
}
/// Use this function to increase `their_end`.
pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
if (their_amt <= existing_amt) return;
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
const actual_amt = try stream.readAtLeast(dest, request_amt);
if (actual_amt < request_amt) return error.TlsConnectionTruncated;
d.cap += actual_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
}
/// Use this function to increase `our_end`.
/// This should always be called with an amount provided by us, not them.
pub fn ensure(d: *Decoder, amt: usize) !void {
d.our_end = @max(d.idx + amt, d.our_end);
if (d.our_end > d.their_end) return error.TlsDecodeError;
}
/// Use this function to increase `idx`.
pub fn decode(d: *Decoder, comptime T: type) T {
switch (@typeInfo(T)) {
.Int => |info| switch (info.bits) {
8 => {
skip(d, 1);
return d.buf[d.idx - 1];
},
16 => {
skip(d, 2);
const b0: u16 = d.buf[d.idx - 2];
const b1: u16 = d.buf[d.idx - 1];
return (b0 << 8) | b1;
},
24 => {
skip(d, 3);
const b0: u24 = d.buf[d.idx - 3];
const b1: u24 = d.buf[d.idx - 2];
const b2: u24 = d.buf[d.idx - 1];
return (b0 << 16) | (b1 << 8) | b2;
},
else => @compileError("unsupported int type: " ++ @typeName(T)),
},
.Enum => |info| {
const int = d.decode(info.tag_type);
if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
return @intToEnum(T, int);
},
else => @compileError("unsupported type: " ++ @typeName(T)),
}
}
/// Use this function to increase `idx`.
pub fn array(d: *Decoder, comptime len: usize) *[len]u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn slice(d: *Decoder, len: usize) []u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn skip(d: *Decoder, amt: usize) void {
d.idx += amt;
assert(d.idx <= d.our_end); // insufficient ensured bytes
}
pub fn eof(d: Decoder) bool {
assert(d.our_end <= d.their_end);
assert(d.idx <= d.our_end);
return d.idx == d.their_end;
}
/// Provide the length they claim, and receive a sub-decoder specific to that slice.
/// The parent decoder is advanced to the end.
pub fn sub(d: *Decoder, their_len: usize) !Decoder {
const end = d.idx + their_len;
if (end > d.their_end) return error.TlsDecodeError;
const sub_buf = d.buf[d.idx..end];
d.idx = end;
d.our_end = end;
return fromTheirSlice(sub_buf);
}
pub fn rest(d: Decoder) []u8 {
return d.buf[d.idx..d.cap];
}
};

View File

@ -126,88 +126,73 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const client_hello_bytes1 = plaintext_header[5..];
var handshake_cipher: tls.HandshakeCipher = undefined;
var handshake_buf: [8000]u8 = undefined;
var len: usize = 0;
var i: usize = i: {
const plaintext = handshake_buf[0..5];
len = try stream.readAtLeast(&handshake_buf, plaintext.len);
if (len < plaintext.len) return error.EndOfStream;
const ct = @intToEnum(tls.ContentType, plaintext[0]);
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
const end = plaintext.len + frag_len;
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
len += try stream.readAtLeast(handshake_buf[len..], end - len);
if (end > len) return error.EndOfStream;
}
const frag = handshake_buf[plaintext.len..end];
var handshake_buffer: [8000]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
{
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const ct = d.decode(tls.ContentType);
d.skip(2); // legacy_record_version
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
const server_hello_fragment = d.buf[d.idx..][0..record_len];
var ptd = try d.sub(record_len);
switch (ct) {
.alert => {
const level = @intToEnum(tls.AlertLevel, frag[0]);
const desc = @intToEnum(tls.AlertDescription, frag[1]);
try ptd.ensure(2);
const level = ptd.decode(tls.AlertLevel);
const desc = ptd.decode(tls.AlertDescription);
_ = level;
_ = desc;
return error.TlsAlert;
},
.handshake => {
if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) {
try ptd.ensure(4);
const handshake_type = ptd.decode(tls.HandshakeType);
if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
const length = ptd.decode(u24);
var hsd = try ptd.sub(length);
try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
const legacy_version = hsd.decode(u16);
const random = hsd.array(32);
if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
// This is a HelloRetryRequest message. This client implementation
// does not expect to get one.
return error.TlsUnexpectedMessage;
}
const length = mem.readIntBig(u24, frag[1..4]);
if (4 + length != frag.len) return error.TlsBadLength;
var i: usize = 4;
const legacy_version = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const random = frag[i..][0..32].*;
i += 32;
if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) {
@panic("TODO handle HelloRetryRequest");
}
const legacy_session_id_echo_len = frag[i];
i += 1;
const legacy_session_id_echo_len = hsd.decode(u8);
if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
const legacy_session_id_echo = frag[i..][0..32];
const legacy_session_id_echo = hsd.array(32);
if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
return error.TlsIllegalParameter;
i += 32;
const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int);
const legacy_compression_method = frag[i];
i += 1;
_ = legacy_compression_method;
const extensions_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
if (i + extensions_size != frag.len) return error.TlsBadLength;
const cipher_suite_tag = hsd.decode(tls.CipherSuite);
hsd.skip(1); // legacy_compression_method
const extensions_size = hsd.decode(u16);
var all_extd = try hsd.sub(extensions_size);
var supported_version: u16 = 0;
var shared_key: [32]u8 = undefined;
var have_shared_key = false;
while (i < frag.len) {
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2]));
i += 2;
const ext_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const next_i = i + ext_size;
if (next_i > frag.len) return error.TlsBadLength;
while (!all_extd.eof()) {
try all_extd.ensure(2 + 2);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
switch (et) {
.supported_versions => {
if (supported_version != 0) return error.TlsIllegalParameter;
supported_version = mem.readIntBig(u16, frag[i..][0..2]);
try extd.ensure(2);
supported_version = extd.decode(u16);
},
.key_share => {
if (have_shared_key) return error.TlsIllegalParameter;
have_shared_key = true;
const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2]));
i += 2;
const key_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
try extd.ensure(4);
const named_group = extd.decode(tls.NamedGroup);
const key_size = extd.decode(u16);
try extd.ensure(key_size);
switch (named_group) {
.x25519 => {
if (key_size != 32) return error.TlsBadLength;
const server_pub_key = frag[i..][0..32];
if (key_size != 32) return error.TlsIllegalParameter;
const server_pub_key = extd.array(32);
shared_key = crypto.dh.X25519.scalarmult(
x25519_kp.secret_key,
@ -215,7 +200,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
) catch return error.TlsDecryptFailure;
},
.secp256r1 => {
const server_pub_key = frag[i..][0..key_size];
const server_pub_key = extd.slice(key_size);
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
const pk = PublicKey.fromSec1(server_pub_key) catch {
@ -233,14 +218,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
else => {},
}
i = next_i;
}
if (!have_shared_key) return error.TlsIllegalParameter;
const tls_version = if (supported_version == 0) legacy_version else supported_version;
switch (tls_version) {
@enumToInt(tls.ProtocolVersion.tls_1_3) => {},
else => return error.TlsIllegalParameter,
}
if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3))
return error.TlsIllegalParameter;
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
@ -264,7 +247,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const p = &@field(handshake_cipher, @tagName(tag));
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
p.transcript_hash.update(host); // Client Hello part 2
p.transcript_hash.update(frag); // Server Hello
p.transcript_hash.update(server_hello_fragment);
const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
@ -289,8 +272,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
else => return error.TlsUnexpectedMessage,
}
break :i end;
};
}
// This is used for two purposes:
// * Detect whether a certificate is the first one presented, in which case
@ -322,29 +304,17 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
var main_cert_pub_key_len: u16 = undefined;
while (true) {
const end_hdr = i + 5;
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
if (end_hdr > len) {
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
if (end_hdr > len) return error.EndOfStream;
}
const ct = @intToEnum(tls.ContentType, handshake_buf[i]);
i += 1;
const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
_ = legacy_version;
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
const end = i + record_size;
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
len += try stream.readAtLeast(handshake_buf[len..], end - len);
if (end > len) return error.EndOfStream;
}
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const record_header = d.buf[d.idx..][0..5];
const ct = d.decode(tls.ContentType);
d.skip(2); // legacy_version
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
var record_decoder = try d.sub(record_len);
switch (ct) {
.change_cipher_spec => {
if (record_size != 1) return error.TlsUnexpectedMessage;
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
try record_decoder.ensure(1);
if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
},
.application_data => {
const cleartext_buf = &cleartext_bufs[cert_index % 2];
@ -352,38 +322,34 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const cleartext = switch (handshake_cipher) {
inline else => |*p| c: {
const P = @TypeOf(p.*);
const ciphertext_len = record_size - P.AEAD.tag_length;
const ciphertext = handshake_buf[i..][0..ciphertext_len];
i += ciphertext.len;
const ciphertext_len = record_len - P.AEAD.tag_length;
try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
const ciphertext = record_decoder.slice(ciphertext_len);
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*;
const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1;
const nonce = @as(V, p.server_handshake_iv) ^ operand;
const ad = handshake_buf[end_hdr - 5 ..][0..5];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
return error.TlsBadRecordMac;
break :c cleartext;
},
};
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) {
.handshake => {
var ct_i: usize = 0;
if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
while (true) {
const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
ct_i += 1;
const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
ct_i += 3;
const next_handshake_i = ct_i + handshake_len;
if (next_handshake_i > cleartext.len - 1)
return error.TlsBadLength;
const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
const handshake = cleartext[ct_i..next_handshake_i];
try ctd.ensure(4);
const handshake_type = ctd.decode(tls.HandshakeType);
const handshake_len = ctd.decode(u24);
var hsd = try ctd.sub(handshake_len);
const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
switch (handshake_type) {
.encrypted_extensions => {
if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
@ -391,20 +357,19 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
const total_ext_size = mem.readIntBig(u16, handshake[0..2]);
var hs_i: usize = 2;
const end_ext_i = 2 + total_ext_size;
while (hs_i < end_ext_i) {
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2]));
hs_i += 2;
const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
hs_i += 2;
const next_ext_i = hs_i + ext_size;
try hsd.ensure(2);
const total_ext_size = hsd.decode(u16);
var all_extd = try hsd.sub(total_ext_size);
while (!all_extd.eof()) {
try all_extd.ensure(4);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
_ = extd;
switch (et) {
.server_name => {},
else => {},
}
hs_i = next_ext_i;
}
},
.certificate => cert: {
@ -416,21 +381,19 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
.trust_chain_established => break :cert,
else => return error.TlsUnexpectedMessage,
}
var hs_i: u32 = 0;
const cert_req_ctx_len = handshake[hs_i];
hs_i += 1;
try hsd.ensure(1 + 4);
const cert_req_ctx_len = hsd.decode(u8);
if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
hs_i += 3;
const end_certs = hs_i + certs_size;
while (hs_i < end_certs) {
const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
hs_i += 3;
const end_cert = hs_i + cert_size;
const certs_size = hsd.decode(u24);
var certs_decoder = try hsd.sub(certs_size);
while (!certs_decoder.eof()) {
try certs_decoder.ensure(3);
const cert_size = certs_decoder.decode(u24);
var certd = try certs_decoder.sub(cert_size);
const subject_cert: Certificate = .{
.buffer = handshake,
.index = hs_i,
.buffer = certd.buf,
.index = @intCast(u32, certd.idx),
};
const subject = try subject_cert.parse();
if (cert_index == 0) {
@ -439,9 +402,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
return error.TlsCertificateHostMismatch;
}
// Keep track of the public key for
// the certificate_verify message
// later.
// Keep track of the public key for the
// certificate_verify message later.
main_cert_pub_key_algo = subject.pub_key_algo;
const pub_key = subject.pubKey();
if (pub_key.len > main_cert_pub_key_buf.len)
@ -463,10 +425,10 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
prev_cert = subject;
cert_index += 1;
hs_i = end_cert;
const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
hs_i += 2;
hs_i += total_ext_size;
try certs_decoder.ensure(2);
const total_ext_size = certs_decoder.decode(u16);
var all_extd = try certs_decoder.sub(total_ext_size);
_ = all_extd;
}
},
.certificate_verify => {
@ -476,10 +438,11 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
else => return error.TlsUnexpectedMessage,
}
const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2]));
const sig_len = mem.readIntBig(u16, handshake[2..4]);
if (4 + sig_len > handshake.len) return error.TlsBadLength;
const encoded_sig = handshake[4..][0..sig_len];
try hsd.ensure(4);
const scheme = hsd.decode(tls.SignatureScheme);
const sig_len = hsd.decode(u16);
try hsd.ensure(sig_len);
const encoded_sig = hsd.slice(sig_len);
const max_digest_len = 64;
var verify_buffer =
([1]u8{0x20} ** 64) ++
@ -588,40 +551,32 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
});
},
};
const leftover = d.rest();
var client: Client = .{
.read_seq = 0,
.write_seq = 0,
.partial_cleartext_idx = 0,
.partial_ciphertext_idx = 0,
.partial_ciphertext_end = @intCast(u15, len - end),
.partial_ciphertext_end = @intCast(u15, leftover.len),
.received_close_notify = false,
.application_cipher = app_cipher,
.partially_read_buffer = undefined,
};
mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]);
mem.copy(u8, &client.partially_read_buffer, leftover);
return client;
},
else => {
return error.TlsUnexpectedMessage;
},
}
ct_i = next_handshake_i;
if (ct_i >= cleartext.len - 1) break;
if (ctd.eof()) break;
}
},
else => {
return error.TlsUnexpectedMessage;
},
}
},
else => {
return error.TlsUnexpectedMessage;
},
}
i = end;
}
return error.TlsHandshakeFailure;
}
pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
@ -638,12 +593,12 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
inline else => |*p| l: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1;
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
while (true) {
const encrypted_content_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
ciphertext_buf.len -
tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
));
if (encrypted_content_len == 0) break :l overhead_len;
@ -829,7 +784,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// Cleartext capacity of output buffer, in records, rounded up.
const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len);
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
const actual_read_len = try stream.readv(ask_iovecs);
@ -860,13 +815,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
continue;
}
if (in + tls.ciphertext_record_header_len > frag.len) {
if (in + tls.record_header_len > frag.len) {
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
const first = frag[in..];
if (frag1.len < tls.ciphertext_record_header_len)
if (frag1.len < tls.record_header_len)
return finishRead2(c, first, frag1, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
@ -875,7 +830,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
const full_record_len = record_len + tls.ciphertext_record_header_len;
const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@ -898,14 +853,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const end = in + record_len;
if (end > frag.len) {
// We need the record header on the next iteration of the loop.
in -= tls.ciphertext_record_header_len;
in -= tls.record_header_len;
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
const first = frag[in..];
const full_record_len = record_len + tls.ciphertext_record_header_len;
const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@ -919,7 +874,12 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
}
switch (ct) {
.alert => {
@panic("TODO handle an alert here");
if (in + 2 > frag.len) return error.TlsDecodeError;
const level = @intToEnum(tls.AlertLevel, frag[in]);
const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
_ = level;
_ = desc;
return error.TlsAlert;
},
.application_data => {
const cleartext = switch (c.application_cipher) {