Merge pull request #16233 from jacobly0/tls

crypto.tls.Client: fix occasional crash in `readvAdvanced`
This commit is contained in:
Andrew Kelley 2023-06-27 00:45:49 -07:00 committed by GitHub
commit d9e867172e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -140,7 +140,7 @@ pub fn InitError(comptime Stream: type) type {
///
/// `host` is only borrowed during this function call.
pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client {
const host_len = @as(u16, @intCast(host.len));
const host_len: u16 = @intCast(host.len);
var random_buffer: [128]u8 = undefined;
crypto.random.bytes(&random_buffer);
@ -194,7 +194,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
int2(host_len);
const extensions_header =
int2(@as(u16, @intCast(extensions_payload.len + host_len))) ++
int2(@intCast(extensions_payload.len + host_len)) ++
extensions_payload;
const legacy_compression_methods = 0x0100;
@ -209,13 +209,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
const out_handshake =
[_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++
int3(@as(u24, @intCast(client_hello.len + host_len))) ++
int3(@intCast(client_hello.len + host_len)) ++
client_hello;
const plaintext_header = [_]u8{
@intFromEnum(tls.ContentType.handshake),
0x03, 0x01, // legacy_record_version
} ++ int2(@as(u16, @intCast(out_handshake.len + host_len))) ++ out_handshake;
} ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake;
{
var iovecs = [_]std.os.iovec_const{
@ -466,7 +466,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
},
};
const inner_ct = @as(tls.ContentType, @enumFromInt(cleartext[cleartext.len - 1]));
const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
@ -520,7 +520,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
const subject_cert: Certificate = .{
.buffer = certd.buf,
.index = @as(u32, @intCast(certd.idx)),
.index = @intCast(certd.idx),
};
const subject = try subject_cert.parse();
if (cert_index == 0) {
@ -534,7 +534,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
if (pub_key.len > main_cert_pub_key_buf.len)
return error.CertificatePublicKeyInvalid;
@memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key);
main_cert_pub_key_len = @as(@TypeOf(main_cert_pub_key_len), @intCast(pub_key.len));
main_cert_pub_key_len = @intCast(pub_key.len);
} else {
try prev_cert.verify(subject, now_sec);
}
@ -679,7 +679,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
.write_seq = 0,
.partial_cleartext_idx = 0,
.partial_ciphertext_idx = 0,
.partial_ciphertext_end = @as(u15, @intCast(leftover.len)),
.partial_ciphertext_end = @intCast(leftover.len),
.received_close_notify = false,
.application_cipher = app_cipher,
.partially_read_buffer = undefined,
@ -797,11 +797,11 @@ fn prepareCiphertextRecord(
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const encrypted_content_len = @as(u16, @intCast(@min(
const encrypted_content_len: u16 = @intCast(@min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
ciphertext_buf.len - close_notify_alert_reserved -
overhead_len - ciphertext_end,
)));
));
if (encrypted_content_len == 0) return .{
.iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
@ -920,7 +920,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
// Give away the buffered cleartext we have, if any.
const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
if (partial_cleartext.len > 0) {
const amt = @as(u15, @intCast(vp.put(partial_cleartext)));
const amt: u15 = @intCast(vp.put(partial_cleartext));
c.partial_cleartext_idx += amt;
if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
@ -958,6 +958,13 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
// The amount of the user's buffer that will be used to give cleartext. The
// beginning of the buffer will be used for such purposes.
const cleartext_buf_len = free_size - ciphertext_buf_len;
// Recoup `partially_read_buffer space`. This is necessary because it is assumed
// below that `frag0` is big enough to hold at least one record.
limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx);
c.partial_ciphertext_end -= c.partial_ciphertext_idx;
c.partial_ciphertext_idx = 0;
c.partial_cleartext_idx = 0;
const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
var ask_iovecs_buf: [2]std.os.iovec = .{
@ -1037,7 +1044,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
in = 0;
continue;
}
const ct = @as(tls.ContentType, @enumFromInt(frag[in]));
const ct: tls.ContentType = @enumFromInt(frag[in]);
in += 1;
const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
in += 2;
@ -1070,8 +1077,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
switch (ct) {
.alert => {
if (in + 2 > frag.len) return error.TlsDecodeError;
const level = @as(tls.AlertLevel, @enumFromInt(frag[in]));
const desc = @as(tls.AlertDescription, @enumFromInt(frag[in + 1]));
const level: tls.AlertLevel = @enumFromInt(frag[in]);
const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]);
_ = level;
try desc.toError();
@ -1105,11 +1112,11 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
c.read_seq = try std.math.add(u64, c.read_seq, 1);
const inner_ct = @as(tls.ContentType, @enumFromInt(cleartext[cleartext.len - 1]));
const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
switch (inner_ct) {
.alert => {
const level = @as(tls.AlertLevel, @enumFromInt(cleartext[0]));
const desc = @as(tls.AlertDescription, @enumFromInt(cleartext[1]));
const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
if (desc == .close_notify) {
c.received_close_notify = true;
c.partial_ciphertext_end = c.partial_ciphertext_idx;
@ -1124,7 +1131,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
.handshake => {
var ct_i: usize = 0;
while (true) {
const handshake_type = @as(tls.HandshakeType, @enumFromInt(cleartext[ct_i]));
const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
ct_i += 1;
const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
ct_i += 3;
@ -1186,13 +1193,13 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len],
msg,
);
c.partial_ciphertext_idx = @as(@TypeOf(c.partial_ciphertext_idx), @intCast(c.partial_ciphertext_idx + msg.len));
c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len);
} else {
const amt = vp.put(msg);
if (amt < msg.len) {
const rest = msg[amt..];
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = @as(@TypeOf(c.partial_ciphertext_idx), @intCast(rest.len));
c.partial_ciphertext_idx = @intCast(rest.len);
@memcpy(c.partially_read_buffer[0..rest.len], rest);
}
}
@ -1220,12 +1227,12 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
const saved_buf = frag[in..];
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// There is cleartext at the beginning already which we need to preserve.
c.partial_ciphertext_end = @as(@TypeOf(c.partial_ciphertext_end), @intCast(c.partial_ciphertext_idx + saved_buf.len));
c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len);
@memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf);
} else {
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
c.partial_ciphertext_end = @as(@TypeOf(c.partial_ciphertext_end), @intCast(saved_buf.len));
c.partial_ciphertext_end = @intCast(saved_buf.len);
@memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
}
return out;
@ -1235,14 +1242,14 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// There is cleartext at the beginning already which we need to preserve.
c.partial_ciphertext_end = @as(@TypeOf(c.partial_ciphertext_end), @intCast(c.partial_ciphertext_idx + first.len + frag1.len));
c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
// TODO: eliminate this call to copyForwards
std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
@memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
} else {
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
c.partial_ciphertext_end = @as(@TypeOf(c.partial_ciphertext_end), @intCast(first.len + frag1.len));
c.partial_ciphertext_end = @intCast(first.len + frag1.len);
// TODO: eliminate this call to copyForwards
std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
@memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);