diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index b8fa6f2313..ebb70a0e06 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -25,9 +25,7 @@ pub const VTable = struct { /// /// Returns the number of bytes written, which will be at minimum `0` and /// at most `limit`. The number returned, including zero, does not indicate - /// end of stream. `limit` is guaranteed to be at least as large as the - /// buffer capacity of `w`, a value whose minimum size is determined by the - /// stream implementation. + /// end of stream. /// /// The reader's internal logical seek position moves forward in accordance /// with the number of bytes returned from this function. diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 5e89c071c6..aef9a60232 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -61,9 +61,6 @@ pub const ReadError = error{ TlsUnexpectedMessage, TlsIllegalParameter, TlsSequenceOverflow, - /// The buffer provided to the read function was not at least - /// `min_buffer_len`. - OutputBufferUndersize, }; pub const SslKeyLog = struct { @@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client }; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch return error.TlsBadRecordMac; - cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len; + // TODO use scalar, non-slice version + cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len; }, } read_seq += 1; @@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_fragment_buf[0..message_len]; - const ad = std.mem.toBytes(big(read_seq)) ++ + const ad = mem.toBytes(big(read_seq)) ++ record_header[0 .. 1 + 2] ++ - std.mem.toBytes(big(message_len)); + mem.toBytes(big(message_len)); const record_iv = record_decoder.array(P.record_iv_length).*; const masked_read_seq = read_seq & comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); @@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client &.{ "server finished", &p.transcript_hash.finalResult() }, P.verify_data_length, ), - .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), + .app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block), } }; const pv = &p.version.tls_1_2; const nonce: [P.AEAD.nonce_length]u8 = nonce: { @@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client client_verify_cleartext.len ..][0..client_verify_cleartext.len], client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], &client_verify_cleartext, - std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len), + mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len), nonce, pv.app_cipher.client_write_key, ); @@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client .input = input, .reader = .{ .buffer = options.read_buffer, - .vtable = &.{ .stream = stream }, + .vtable = &.{ + .stream = stream, + .readVec = readVec, + }, .seek = 0, .end = 0, }, @@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord( const nonce = nonce: { const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ std.mem.toBytes(big(c.write_seq)); + const operand: V = pad ++ mem.toBytes(big(c.write_seq)); break :nonce @as(V, pv.client_iv) ^ operand; }; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); @@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord( record_header.* = .{@intFromEnum(inner_content_type)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ int(u16, P.record_iv_length + message_len + P.mac_length); - const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len); + const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len); const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; ciphertext_end += P.record_iv_length; const nonce: [P.AEAD.nonce_length]u8 = nonce: { @@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool { } fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { + // This function writes exclusively to the buffer. + _ = w; + _ = limit; const c: *Client = @alignCast(@fieldParentPtr("reader", r)); + return readIndirect(c); +} + +fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize { + // This function writes exclusively to the buffer. + _ = data; + const c: *Client = @alignCast(@fieldParentPtr("reader", r)); + return readIndirect(c); +} + +fn readIndirect(c: *Client) Reader.Error!usize { + const r = &c.reader; if (c.eof()) return error.EndOfStream; const input = c.input; // If at least one full encrypted record is not buffered, read once. @@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize if (record_end > input.buffered().len) return 0; } - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + if (r.seek == r.end) { + r.seek = 0; + r.end = 0; + } + const cleartext_buffer = r.buffer[r.end..]; + + const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { .tls_1_3 => { const pv = &p.tls_1_3; @@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize const nonce = nonce: { const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + const operand: V = pad ++ mem.toBytes(big(c.read_seq)); break :nonce @as(V, pv.server_iv) ^ operand; }; - const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + const cleartext = cleartext_buffer[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch return failRead(c, error.TlsBadRecordMac); + // TODO use scalar, non-slice version const msg = mem.trimRight(u8, cleartext, "\x00"); - break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) }; }, .tls_1_2 => { const pv = &p.tls_1_2; const P = @TypeOf(p.*); const message_len: u16 = record_len - P.record_iv_length - P.mac_length; const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked - const ad = std.mem.toBytes(big(c.read_seq)) ++ + const ad = mem.toBytes(big(c.read_seq)) ++ ad_header[0 .. 1 + 2] ++ - std.mem.toBytes(big(message_len)); + mem.toBytes(big(message_len)); const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked const masked_read_seq = c.read_seq & comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); @@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize }; const ciphertext = input.take(message_len) catch unreachable; // already peeked const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked - const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + const cleartext = cleartext_buffer[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch return failRead(c, error.TlsBadRecordMac); - break :cleartext .{ cleartext, ct }; + break :cleartext .{ cleartext.len, ct }; }, else => unreachable, }, }; + const cleartext = cleartext_buffer[0..cleartext_len]; c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow); switch (inner_ct) { .alert => { @@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize return 0; }, .application_data => { - if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize); - try w.writeAll(cleartext); - return cleartext.len; + r.end += cleartext.len; + return 0; }, else => return failRead(c, error.TlsUnexpectedMessage), }