std.crypto.tls.Client.readvAdvanced: fix bugs

* When there is buffered cleartext, return it without calling the
   underlying read function. This prevents buffer overflow due to space
   used up by cleartext.
 * Avoid clearing the buffer when the buffered cleartext could not be
   completely given to the result read buffer, and there is some
   buffered ciphertext left.
 * Instead of rounding up the amount of bytes to ask for to the nearest
   TLS record size, round down, with a minimum of 1. This prevents the
   code path from being taken which requires extra memory copies.
 * Avoid calling `@memcpy` with overlapping arguments.

closes #15590
This commit is contained in:
Andrew Kelley 2023-05-17 20:39:12 -07:00
parent 378264d404
commit 7cf2cbb33e

View File

@ -924,7 +924,9 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
const amt = @intCast(u15, vp.put(partial_cleartext));
c.partial_cleartext_idx += amt;
if (c.partial_ciphertext_end == c.partial_ciphertext_idx) {
if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
c.partial_ciphertext_end == c.partial_ciphertext_idx)
{
// The buffer is now empty.
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
@ -935,7 +937,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
c.partial_ciphertext_end = 0;
assert(vp.total == amt);
return amt;
} else if (amt <= partial_cleartext.len) {
} else if (amt > 0) {
// We don't need more data, so don't call read.
assert(vp.total == amt);
return amt;
@ -970,8 +972,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
},
};
// Cleartext capacity of output buffer, in records, rounded up.
const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
// Cleartext capacity of output buffer, in records. Minimum one full record.
const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
@ -1029,7 +1031,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@memcpy(frag[0..in], first);
limitedOverlapCopy(frag, in);
@memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
@ -1059,7 +1061,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@memcpy(frag[0..in], first);
limitedOverlapCopy(frag, in);
@memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
@ -1176,8 +1178,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// We have already run out of room in iovecs. Continue
// appending to `partially_read_buffer`.
const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
@memcpy(dest[0..msg.len], msg);
@memcpy(
c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len],
msg,
);
c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
} else {
const amt = vp.put(msg);
@ -1223,22 +1227,38 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
return out;
}
/// Note that `first` usually overlaps with `c.partially_read_buffer`.
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// There is cleartext at the beginning already which we need to preserve.
c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len);
@memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
// TODO: eliminate this call to copyForwards
std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
@memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
} else {
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len);
// TODO: eliminate this call to copyForwards
std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
@memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
}
return out;
}
fn limitedOverlapCopy(frag: []u8, in: usize) void {
const first = frag[in..];
if (first.len <= in) {
// A single, non-overlapping memcpy suffices.
@memcpy(frag[0..first.len], first);
} else {
// Need two memcpy calls because one alone would overlap.
@memcpy(frag[0..in], first[0..in]);
const leftover = first.len - in;
@memcpy(frag[in..][0..leftover], first[in..][0..leftover]);
}
}
fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
if (index < s1.len) {
return s1[index];