diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig
index 76a1761ec5..ad7b682b7e 100644
--- a/lib/std/crypto/ecdsa.zig
+++ b/lib/std/crypto/ecdsa.zig
@@ -168,7 +168,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
                     has_top_bit = true;
                 }
                 const out_slice = out[out.len - expected_len ..];
-                br.read(out_slice) catch return error.InvalidEncoding;
+                br.readSlice(out_slice) catch return error.InvalidEncoding;
                 if (@intFromBool(has_top_bit) != out[0] >> 7) return error.InvalidEncoding;
             }

@@ -177,7 +177,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
         pub fn fromDer(der: []const u8) EncodingError!Signature {
             if (der.len < 2) return error.InvalidEncoding;
             var br: std.io.BufferedReader = undefined;
-            br.initFixed(der);
+            br.initFixed(@constCast(der));
             const buf = br.take(2) catch return error.InvalidEncoding;
             if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) return error.InvalidEncoding;
             var sig: Signature = mem.zeroInit(Signature, .{});
diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig
index c2071bcbc5..73a5fef528 100644
--- a/lib/std/crypto/tls.zig
+++ b/lib/std/crypto/tls.zig
@@ -655,7 +655,7 @@ pub const Decoder = struct {
     }

     /// Use this function to increase `their_end`.
-    pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+    pub fn readAtLeast(d: *Decoder, stream: *std.io.BufferedReader, their_amt: usize) !void {
         assert(!d.disable_reads);
         const existing_amt = d.cap - d.idx;
         d.their_end = d.idx + their_amt;
@@ -663,14 +663,16 @@ pub const Decoder = struct {
         const request_amt = their_amt - existing_amt;
         const dest = d.buf[d.cap..];
         if (request_amt > dest.len) return error.TlsRecordOverflow;
-        const actual_amt = try stream.readAtLeast(dest, request_amt);
-        if (actual_amt < request_amt) return error.TlsConnectionTruncated;
-        d.cap += actual_amt;
+        stream.readSlice(dest[0..request_amt]) catch |err| switch (err) {
+            error.EndOfStream => return error.TlsConnectionTruncated,
+            error.ReadFailed => return error.ReadFailed,
+        };
+        d.cap += request_amt;
     }

     /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
     /// Use when `our_amt` is calculated by us, not by them.
-    pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+    pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.BufferedReader, our_amt: usize) !void {
         assert(!d.disable_reads);
         try readAtLeast(d, stream, our_amt);
         d.our_end = d.idx + our_amt;
diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
index a393691828..a422c5d74f 100644
--- a/lib/std/crypto/tls/Client.zig
+++ b/lib/std/crypto/tls/Client.zig
@@ -4,11 +4,12 @@ const native_endian = builtin.cpu.arch.endian();
 const std = @import("../../std.zig");
 const tls = std.crypto.tls;
 const Client = @This();
-const net = std.net;
 const mem = std.mem;
 const crypto = std.crypto;
 const assert = std.debug.assert;
 const Certificate = std.crypto.Certificate;
+const Reader = std.io.Reader;
+const Writer = std.io.Writer;

 const max_ciphertext_len = tls.max_ciphertext_len;
 const hmacExpandLabel = tls.hmacExpandLabel;
@@ -21,38 +22,22 @@ const array = tls.array;
 ///
 /// The buffer is asserted to have capacity at least `min_buffer_len`.
 ///
-/// The size is enough to contain exactly one TLSCiphertext record.
-/// This buffer is segmented into four parts:
-/// 0. unused
-/// 1. cleartext
-/// 2. ciphertext
-/// 3. unused
-/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
-/// `partial_ciphertext_end` describe the span of the segments.
+/// `remaining_cleartext_len` tells how many bytes inside this buffer have
+/// already been decrypted.
 input: *std.io.BufferedReader,
+/// Tells how many bytes inside `input` have already been decrypted.
+remaining_cleartext_len: u15,
+
 /// The encrypted stream from the client to the server. Bytes are pushed here
 /// via `writer`.
-///
-/// The buffer is asserted to have capacity at least `min_buffer_len`.
 output: *std.io.BufferedWriter,
-/// Cleartext received from the server here.
-///
-/// Its buffer aliases the buffer of `input`.
-reader: std.io.BufferedReader,
-/// Populated when `error.TlsAlert` is returned.
-alert: ?tls.Alert,
-read_err: ?ReadError,
+/// Populated when `error.TlsAlert` is returned.
+alert: ?tls.Alert = null,
+read_err: ?ReadError = null,
 tls_version: tls.ProtocolVersion,
 read_seq: u64,
 write_seq: u64,
-/// The starting index of cleartext bytes inside the input buffer.
-partial_cleartext_idx: u15,
-/// The ending index of cleartext bytes inside the input buffer as well
-/// as the starting index of ciphertext bytes.
-partial_ciphertext_idx: u15,
-/// The ending index of ciphertext bytes inside the input buffer.
-partial_ciphertext_end: u15,
 /// When this is true, the stream may still not be at the end because there
 /// may be data in the input buffer.
 received_close_notify: bool,
@@ -60,11 +45,13 @@ received_close_notify: bool,
 /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
 /// message has been received. By setting this flag to `true`, instead, the
 /// end-of-stream will be forwarded to the application layer above TLS.
+///
 /// This makes the application vulnerable to truncation attacks unless the
 /// application layer itself verifies that the amount of data received equals
 /// the amount of data expected, such as HTTP with the Content-Length header.
 allow_truncation_attacks: bool,
 application_cipher: tls.ApplicationCipher,
+
 /// If non-null, ssl secrets are logged to a stream. Creating such a log file
 /// allows other programs with access to that file to decrypt all traffic over
 /// this connection.
@@ -80,6 +67,7 @@ pub const ReadError = error{
     TlsRecordOverflow,
     TlsUnexpectedMessage,
     TlsIllegalParameter,
+    TlsSequenceOverflow,
 };

 pub const SslKeyLog = struct {
@@ -99,8 +87,8 @@ pub const SslKeyLog = struct {
     }
 };

-/// The `std.io.BufferedReader` and `std.io.BufferedWriter` supplied to `init`
-/// each require a buffer capacity at least this amount.
+/// The `std.io.BufferedReader` supplied to `init` requires a buffer capacity
+/// at least this amount.
 pub const min_buffer_len = tls.max_ciphertext_record_len;

 pub const Options = struct {
@@ -126,7 +114,10 @@ pub const Options = struct {
     },
     /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows
     /// other programs with access to that file to decrypt all traffic over this connection.
-    ssl_key_log: ?*std.io.BufferedWriter = null,
+    ///
+    /// Only the `writer` field is observed during the handshake (`init`).
+    /// After that, the other fields are populated.
+    ssl_key_log: ?*SslKeyLog = null,
 };

 const InitError = error{
@@ -183,16 +174,14 @@ const InitError = error{
 ///
 /// `host` is only borrowed during this function call.
 ///
-/// Both `input` and `output` are asserted to have buffer capacity at least
-/// `min_buffer_len`.
+/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
 pub fn init(
     client: *Client,
     input: *std.io.BufferedReader,
     output: *std.io.BufferedWriter,
     options: Options,
 ) InitError!void {
-    assert(input.storage.buffer.len >= min_buffer_len);
-    assert(output.buffer.len >= min_buffer_len);
+    assert(input.buffer.len >= min_buffer_len);
     client.alert = null;
     const host = switch (options.host) {
         .no_verification => "",
@@ -286,7 +275,7 @@ pub fn init(
     {
         var iovecs: [2][]const u8 = .{ cleartext_header, host };
-        try output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
+        try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
     }

     var tls_version: tls.ProtocolVersion = undefined;
@@ -335,20 +324,26 @@ pub fn init(
     var cleartext_fragment_start: usize = 0;
     var cleartext_fragment_end: usize = 0;
     var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
-    var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
-    var d: tls.Decoder = .{ .buf = &handshake_buffer };
     fragment: while (true) {
-        try d.readAtLeastOurAmt(input, tls.record_header_len);
-        const record_header = d.buf[d.idx..][0..tls.record_header_len];
-        const record_ct = d.decode(tls.ContentType);
-        d.skip(2); // legacy_version
-        const record_len = d.decode(u16);
-        try d.readAtLeast(input, record_len);
-        var record_decoder = try d.sub(record_len);
+        // Ensure the input buffer pointer is stable in this scope.
+        input.rebaseCapacity(tls.max_ciphertext_record_len);
+        const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+            error.EndOfStream => return error.TlsConnectionTruncated,
+            error.ReadFailed => return error.ReadFailed,
+        };
+        const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked
+        input.toss(2); // legacy_version
+        const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked
+        if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow;
+        const record_buffer = input.take(record_len) catch |err| switch (err) {
+            error.EndOfStream => return error.TlsConnectionTruncated,
+            error.ReadFailed => return error.ReadFailed,
+        };
+        var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer);
         var ctd, const ct = content: switch (cipher_state) {
             .cleartext => .{ record_decoder, record_ct },
             .handshake => {
-                std.debug.assert(tls_version == .tls_1_3);
+                assert(tls_version == .tls_1_3);
                 if (record_ct != .application_data) return error.TlsUnexpectedMessage;
                 try record_decoder.ensure(record_len);
                 const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -380,7 +375,7 @@ pub fn init(
                 break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
             },
             .application => {
-                std.debug.assert(tls_version == .tls_1_2);
+                assert(tls_version == .tls_1_2);
                 if (record_ct != .handshake) return error.TlsUnexpectedMessage;
                 try record_decoder.ensure(record_len);
                 const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -536,7 +531,7 @@ pub fn init(
                     pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
                     const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
                     const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
-                    if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                    if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
                         .client_random = &client_hello_rand,
                     }, .{
                         .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret,
@@ -710,7 +705,7 @@ pub fn init(
                     &client_hello_rand,
                     &server_hello_rand,
                 }, 48);
-                if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
                     .client_random = &client_hello_rand,
                 }, .{
                     .CLIENT_RANDOM = &master_secret,
@@ -763,7 +758,7 @@ pub fn init(
                         &client_change_cipher_spec_msg,
                         &client_verify_msg,
                     };
-                    try output.writevAll(&all_msgs_vec);
+                    try output.writeVecAll(&all_msgs_vec);
                 },
             }
             write_seq += 1;
@@ -828,11 +823,11 @@ pub fn init(
                         &client_change_cipher_spec_msg,
                         &finished_msg,
                     };
-                    try output.writevAll(&all_msgs_vec);
+                    try output.writeVecAll(&all_msgs_vec);
                     const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
                     const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
-                    if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                    if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
                         .counter = key_seq,
                         .client_random = &client_hello_rand,
                     }, .{
@@ -859,11 +854,9 @@ pub fn init(
                     else => unreachable,
                 },
             };
-            const leftover = d.rest();
             client.* = .{
                 .input = input,
                 .output = output,
-                .reader = undefined,
                 .tls_version = tls_version,
                 .read_seq = switch (tls_version) {
                     .tls_1_3 => 0,
@@ -875,29 +868,18 @@ pub fn init(
                     .tls_1_2 => write_seq,
                     else => unreachable,
                 },
-                .partial_cleartext_idx = 0,
-                .partial_ciphertext_idx = 0,
-                .partial_ciphertext_end = @intCast(leftover.len),
+                .remaining_cleartext_len = 0,
                 .received_close_notify = false,
                 .allow_truncation_attacks = false,
                 .application_cipher = app_cipher,
-                .partially_read_buffer = undefined,
-                .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
-                    .client_key_seq = key_seq,
-                    .server_key_seq = key_seq,
-                    .client_random = client_hello_rand,
-                    .file = key_log_file,
-                } else null,
+                .ssl_key_log = options.ssl_key_log,
+            };
+            if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
+                .client_key_seq = key_seq,
+                .server_key_seq = key_seq,
+                .client_random = client_hello_rand,
+                .writer = ssl_key_log.writer,
             };
-            @memcpy(client.partially_read_buffer[0..leftover.len], leftover);
-            client.reader.init(.{
-                .context = client,
-                .vtable = &.{
-                    .read = read,
-                    .readVec = readVec,
-                    .discard = discard,
-                },
-            }, input.storage.buffer[0..0]);
             return;
         },
         else => return error.TlsUnexpectedMessage,
@@ -912,17 +894,28 @@ pub fn init(
     }
 }

-pub fn writer(c: *Client) std.io.Writer {
+pub fn reader(c: *Client) Reader {
     return .{
         .context = c,
         .vtable = &.{
-            .writeSplat = writeSplat,
-            .writeFile = std.io.Writer.unimplementedWriteFile,
+            .read = read,
+            .readVec = readVec,
+            .discard = discard,
         },
     };
 }

-fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
+pub fn writer(c: *Client) Writer {
+    return .{
+        .context = c,
+        .vtable = &.{
+            .writeSplat = writeSplat,
+            .writeFile = Writer.unimplementedWriteFile,
+        },
+    };
+}
+
+fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Error!usize {
     const c: *Client = @alignCast(@ptrCast(context));
     const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
     const output = c.output;
@@ -942,7 +935,7 @@ fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.i
 /// Sends a `close_notify` alert, which is necessary for the server to
 /// distinguish between a properly finished TLS session, or a truncation
 /// attack.
-pub fn end(c: *Client) std.io.Writer.Error!void {
+pub fn end(c: *Client) Writer.Error!void {
     const output = c.output;
     const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
     const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
@@ -1054,372 +1047,212 @@ fn prepareCiphertextRecord(
 }

 pub fn eof(c: Client) bool {
-    return c.received_close_notify and
-        c.partial_cleartext_idx >= c.partial_ciphertext_idx and
-        c.partial_ciphertext_idx >= c.partial_ciphertext_end;
+    return c.received_close_notify and c.remaining_cleartext_len == 0;
 }

-fn read(
-    context: ?*anyopaque,
-    bw: *std.io.BufferedWriter,
-    limit: std.io.Reader.Limit,
-) std.io.Reader.RwError!usize {
-    const buf = limit.slice(try bw.writableSliceGreedy(1));
-    const n = try readVec(context, &.{buf});
-    bw.advance(n);
-    return n;
-}
-
-fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
+fn read(context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Reader.Limit) Reader.RwError!usize {
     const c: *Client = @ptrCast(@alignCast(context));
     if (c.eof()) return error.EndOfStream;
-
-    var vp: VecPut = .{ .iovecs = data };
-
-    // Give away the buffered cleartext we have, if any.
-    const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
-    if (partial_cleartext.len > 0) {
-        const amt: u15 = @intCast(vp.put(partial_cleartext));
-        c.partial_cleartext_idx += amt;
-
-        if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
-            c.partial_ciphertext_end == c.partial_ciphertext_idx)
-        {
-            // The buffer is now empty.
-            c.partial_cleartext_idx = 0;
-            c.partial_ciphertext_idx = 0;
-            c.partial_ciphertext_end = 0;
-        }
-
-        if (c.received_close_notify) {
-            c.partial_ciphertext_end = 0;
-            assert(vp.total == amt);
-            return amt;
-        } else if (amt > 0) {
-            // We don't need more data, so don't call read.
-            assert(vp.total == amt);
-            return amt;
-        }
+    const input = c.input;
+    if (c.remaining_cleartext_len > 0) {
+        const n = try bw.write(input.bufferContents()[0..c.remaining_cleartext_len]);
+        c.remaining_cleartext_len = @intCast(c.remaining_cleartext_len - n);
+        return n;
+    }
+    // If at least one full encrypted record is not buffered, read once.
+    const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+        error.EndOfStream => {
+            // This is either a truncation attack, a bug in the server, or an
+            // intentional omission of the close_notify message due to truncation
+            // detection handled above the TLS layer.
+            if (c.allow_truncation_attacks) {
+                c.received_close_notify = true;
+                return error.EndOfStream;
+            } else {
+                return failRead(c, error.TlsConnectionTruncated);
+            }
+        },
+        error.ReadFailed => return error.ReadFailed,
+    };
+    const ct: tls.ContentType = @enumFromInt(record_header[0]);
+    const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big);
+    _ = legacy_version;
+    const record_len = mem.readInt(u16, record_header[3..][0..2], .big);
+    if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
+    const record_end = 5 + record_len;
+    if (record_end > input.bufferContents().len) {
+        input.fillMore() catch |err| switch (err) {
+            error.EndOfStream => return failRead(c, error.TlsConnectionTruncated),
+            error.ReadFailed => return error.ReadFailed,
+        };
+        if (record_end > input.bufferContents().len) return 0;
     }
-    assert(!c.received_close_notify);
-
-    // Ideally, this buffer would never be used. It is needed when `iovecs` are
-    // too small to fit the cleartext, which may be as large as `max_ciphertext_len`.
     var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
-    // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`.
-    var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
-    // How many bytes left in the user's buffer.
-    const free_size = vp.freeSize();
-    // The amount of the user's buffer that we need to repurpose for storing
-    // ciphertext. The end of the buffer will be used for such purposes.
-    const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len;
-    // The amount of the user's buffer that will be used to give cleartext. The
-    // beginning of the buffer will be used for such purposes.
-    const cleartext_buf_len = free_size - ciphertext_buf_len;
-
-    // Recoup `partially_read_buffer` space. This is necessary because it is assumed
-    // below that `frag0` is big enough to hold at least one record.
-    limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx);
-    c.partial_ciphertext_end -= c.partial_ciphertext_idx;
-    c.partial_ciphertext_idx = 0;
-    c.partial_cleartext_idx = 0;
-    const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
-
-    var ask_iovecs_buf: [2]std.posix.iovec = .{
-        .{
-            .base = first_iov.ptr,
-            .len = first_iov.len,
-        },
-        .{
-            .base = &in_stack_buffer,
-            .len = in_stack_buffer.len,
-        },
-    };
-
-    // Cleartext capacity of output buffer, in records. Minimum one full record.
-    const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
-    const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
-    const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end;
-    const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
-    const actual_read_len = try c.input.readv(ask_iovecs);
-    if (actual_read_len == 0) {
-        // This is either a truncation attack, a bug in the server, or an
-        // intentional omission of the close_notify message due to truncation
-        // detection handled above the TLS layer.
-        if (c.allow_truncation_attacks) {
-            c.received_close_notify = true;
-        } else {
-            return failRead(c, error.TlsConnectionTruncated);
-        }
-    }
-
-    // There might be more bytes inside `in_stack_buffer` that need to be processed,
-    // but at least frag0 will have one complete ciphertext record.
-    const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
-    const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
-    var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
-    // We need to decipher frag0 and frag1 but there may be a ciphertext record
-    // straddling the boundary. We can handle this with two memcpy() calls to
-    // assemble the straddling record in between handling the two sides.
-    var frag = frag0;
-    var in: usize = 0;
-    while (true) {
-        if (in == frag.len) {
-            // Perfect split.
-            if (frag.ptr == frag1.ptr) {
-                c.partial_ciphertext_end = c.partial_ciphertext_idx;
-                return vp.total;
-            }
-            frag = frag1;
-            in = 0;
-            continue;
-        }
-
-        if (in + tls.record_header_len > frag.len) {
-            if (frag.ptr == frag1.ptr)
-                return finishRead(c, frag, in, vp.total);
-
-            const first = frag[in..];
-
-            if (frag1.len < tls.record_header_len)
-                return finishRead2(c, first, frag1, vp.total);
-
-            // A record straddles the two fragments. Copy into the now-empty first fragment.
-            const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
-            const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
-            const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
-            if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
-
-            const full_record_len = record_len + tls.record_header_len;
-            const second_len = full_record_len - first.len;
-            if (frag1.len < second_len)
-                return finishRead2(c, first, frag1, vp.total);
-
-            limitedOverlapCopy(frag, in);
-            @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
-            frag = frag[0..full_record_len];
-            frag1 = frag1[second_len..];
-            in = 0;
-            continue;
-        }
-        const ct: tls.ContentType = @enumFromInt(frag[in]);
-        in += 1;
-        const legacy_version = mem.readInt(u16, frag[in..][0..2], .big);
-        in += 2;
-        _ = legacy_version;
-        const record_len = mem.readInt(u16, frag[in..][0..2], .big);
-        if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
-        in += 2;
-        const the_end = in + record_len;
-        if (the_end > frag.len) {
-            // We need the record header on the next iteration of the loop.
-            in -= tls.record_header_len;
-
-            if (frag.ptr == frag1.ptr)
-                return finishRead(c, frag, in, vp.total);
-
-            // A record straddles the two fragments. Copy into the now-empty first fragment.
-            const first = frag[in..];
-            const full_record_len = record_len + tls.record_header_len;
-            const second_len = full_record_len - first.len;
-            if (frag1.len < second_len)
-                return finishRead2(c, first, frag1, vp.total);
-
-            limitedOverlapCopy(frag, in);
-            @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
-            frag = frag[0..full_record_len];
-            frag1 = frag1[second_len..];
-            in = 0;
-            continue;
-        }
-        const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
-            inline else => |*p| switch (c.tls_version) {
-                .tls_1_3 => {
-                    const pv = &p.tls_1_3;
-                    const P = @TypeOf(p.*);
-                    const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len];
-                    const ciphertext_len = record_len - P.AEAD.tag_length;
-                    const ciphertext = frag[in..][0..ciphertext_len];
-                    in += ciphertext_len;
-                    const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
-                    const nonce = nonce: {
-                        const V = @Vector(P.AEAD.nonce_length, u8);
-                        const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
-                        const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
-                        break :nonce @as(V, pv.server_iv) ^ operand;
-                    };
-                    const out_buf = vp.peek();
-                    const cleartext_buf = if (ciphertext.len <= out_buf.len)
-                        out_buf
-                    else
-                        &cleartext_stack_buffer;
-                    const cleartext = cleartext_buf[0..ciphertext.len];
-                    P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
-                        return failRead(c, error.TlsBadRecordMac);
-                    const msg = mem.trimEnd(u8, cleartext, "\x00");
-                    break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
-                },
-                .tls_1_2 => {
-                    const pv = &p.tls_1_2;
-                    const P = @TypeOf(p.*);
-                    const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
-                    const ad = std.mem.toBytes(big(c.read_seq)) ++
-                        frag[in - tls.record_header_len ..][0 .. 1 + 2] ++
-                        std.mem.toBytes(big(message_len));
-                    const record_iv = frag[in..][0..P.record_iv_length].*;
-                    in += P.record_iv_length;
-                    const masked_read_seq = c.read_seq &
-                        comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
-                    const nonce: [P.AEAD.nonce_length]u8 = nonce: {
-                        const V = @Vector(P.AEAD.nonce_length, u8);
-                        const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
-                        const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
-                        break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
-                    };
-                    const ciphertext = frag[in..][0..message_len];
-                    in += message_len;
-                    const auth_tag = frag[in..][0..P.mac_length].*;
-                    in += P.mac_length;
-                    const out_buf = vp.peek();
-                    const cleartext_buf = if (message_len <= out_buf.len)
-                        out_buf
-                    else
-                        &cleartext_stack_buffer;
-                    const cleartext = cleartext_buf[0..ciphertext.len];
-                    P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
-                        return failRead(c, error.TlsBadRecordMac);
-                    break :cleartext .{ cleartext, ct };
-                },
-                else => unreachable,
-            },
-        };
-        c.read_seq = try std.math.add(u64, c.read_seq, 1);
-        switch (inner_ct) {
-            .alert => {
-                if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
-                const alert: tls.Alert = .{
-                    .level = @enumFromInt(cleartext[0]),
-                    .description = @enumFromInt(cleartext[1]),
-                };
-                switch (alert.description) {
-                    .close_notify => {
-                        c.received_close_notify = true;
-                        c.partial_ciphertext_end = c.partial_ciphertext_idx;
-                        return vp.total;
-                    },
-                    .user_canceled => {
-                        // TODO: handle server-side closures
-                        return failRead(c, error.TlsUnexpectedMessage);
-                    },
-                    else => {
-                        c.alert = alert;
-                        return failRead(c, error.TlsAlert);
-                    },
-                }
-            },
-            .handshake => {
-                var ct_i: usize = 0;
-                while (true) {
-                    const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
-                    ct_i += 1;
-                    const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
-                    ct_i += 3;
-                    const next_handshake_i = ct_i + handshake_len;
-                    if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
-                    const handshake = cleartext[ct_i..next_handshake_i];
-                    switch (handshake_type) {
-                        .new_session_ticket => {
-                            // This client implementation ignores new session tickets.
-                        },
-                        .key_update => {
-                            switch (c.application_cipher) {
-                                inline else => |*p| {
-                                    const pv = &p.tls_1_3;
-                                    const P = @TypeOf(p.*);
-                                    const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
-                                    if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
-                                        .counter = key_log.serverCounter(),
-                                        .client_random = &key_log.client_random,
-                                    }, .{
-                                        .SERVER_TRAFFIC_SECRET = &server_secret,
-                                    });
-                                    pv.server_secret = server_secret;
-                                    pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
-                                    pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
-                                },
-                            }
-                            c.read_seq = 0;
-
-                            switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
-                                .update_requested => {
-                                    switch (c.application_cipher) {
-                                        inline else => |*p| {
-                                            const pv = &p.tls_1_3;
-                                            const P = @TypeOf(p.*);
-                                            const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
-                                            if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
-                                                .counter = key_log.clientCounter(),
-                                                .client_random = &key_log.client_random,
-                                            }, .{
-                                                .CLIENT_TRAFFIC_SECRET = &client_secret,
-                                            });
-                                            pv.client_secret = client_secret;
-                                            pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
-                                            pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
-                                        },
-                                    }
-                                    c.write_seq = 0;
-                                },
-                                .update_not_requested => {},
-                                _ => return failRead(c, error.TlsIllegalParameter),
-                            }
-                        },
-                        else => return failRead(c, error.TlsUnexpectedMessage),
-                    }
-                    ct_i = next_handshake_i;
-                    if (ct_i >= cleartext.len) break;
-                }
-            },
-            .application_data => {
-                // Determine whether the output buffer or a stack
-                // buffer was used for storing the cleartext.
-                if (cleartext.ptr == &cleartext_stack_buffer) {
-                    // Stack buffer was used, so we must copy to the output buffer.
-                    if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
-                        // We have already run out of room in iovecs. Continue
-                        // appending to `partially_read_buffer`.
-                        @memcpy(
-                            c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len],
-                            cleartext,
-                        );
-                        c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len);
-                    } else {
-                        const amt = vp.put(cleartext);
-                        if (amt < cleartext.len) {
-                            const rest = cleartext[amt..];
-                            c.partial_cleartext_idx = 0;
-                            c.partial_ciphertext_idx = @intCast(rest.len);
-                            @memcpy(c.partially_read_buffer[0..rest.len], rest);
-                        }
-                    }
-                } else {
-                    // Output buffer was used directly which means no
-                    // memory copying needs to occur, and we can move
-                    // on to the next ciphertext record.
-                    vp.next(cleartext.len);
-                }
-            },
-            else => return failRead(c, error.TlsUnexpectedMessage),
-        }
-        in = the_end;
-    }
+    const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
+        inline else => |*p| switch (c.tls_version) {
+            .tls_1_3 => {
+                const pv = &p.tls_1_3;
+                const P = @TypeOf(p.*);
+                const ad = input.take(tls.record_header_len) catch unreachable; // already peeked
+                const ciphertext_len = record_len - P.AEAD.tag_length;
+                const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked
+                const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked
+                const nonce = nonce: {
+                    const V = @Vector(P.AEAD.nonce_length, u8);
+                    const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+                    const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
+                    break :nonce @as(V, pv.server_iv) ^ operand;
+                };
+                const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+                P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
+                    return failRead(c, error.TlsBadRecordMac);
+                const msg = mem.trimRight(u8, cleartext, "\x00");
+                break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+            },
+            .tls_1_2 => {
+                const pv = &p.tls_1_2;
+                const P = @TypeOf(p.*);
+                const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
+                const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
+                const ad = std.mem.toBytes(big(c.read_seq)) ++
+                    ad_header[0 .. 1 + 2] ++
+                    std.mem.toBytes(big(message_len));
+                const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
+                const masked_read_seq = c.read_seq &
+                    comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
+                const nonce: [P.AEAD.nonce_length]u8 = nonce: {
+                    const V = @Vector(P.AEAD.nonce_length, u8);
+                    const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+                    const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
+                    break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
+                };
+                const ciphertext = input.take(message_len) catch unreachable; // already peeked
+                const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
+                const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+                P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
+                    return failRead(c, error.TlsBadRecordMac);
+                break :cleartext .{ cleartext, ct };
+            },
+            else => unreachable,
+        },
+    };
+    c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
+    switch (inner_ct) {
+        .alert => {
+            if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
+            const alert: tls.Alert = .{
+                .level = @enumFromInt(cleartext[0]),
+                .description = @enumFromInt(cleartext[1]),
+            };
+            switch (alert.description) {
+                .close_notify => {
+                    c.received_close_notify = true;
+                    return 0;
+                },
+                .user_canceled => {
+                    // TODO: handle server-side closures
+                    return failRead(c, error.TlsUnexpectedMessage);
+                },
+                else => {
+                    c.alert = alert;
+                    return failRead(c, error.TlsAlert);
+                },
+            }
+        },
+        .handshake => {
+            var ct_i: usize = 0;
+            while (true) {
+                const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
+                ct_i += 1;
+                const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
+                ct_i += 3;
+                const next_handshake_i = ct_i + handshake_len;
+                if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
+                const handshake = cleartext[ct_i..next_handshake_i];
+                switch (handshake_type) {
+                    .new_session_ticket => {
+                        // This client implementation ignores new session tickets.
+                    },
+                    .key_update => {
+                        switch (c.application_cipher) {
+                            inline else => |*p| {
+                                const pv = &p.tls_1_3;
+                                const P = @TypeOf(p.*);
+                                const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
+                                if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+                                    .counter = key_log.serverCounter(),
+                                    .client_random = &key_log.client_random,
+                                }, .{
+                                    .SERVER_TRAFFIC_SECRET = &server_secret,
+                                });
+                                pv.server_secret = server_secret;
+                                pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+                                pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+                            },
+                        }
+                        c.read_seq = 0;
+
+                        switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
+                            .update_requested => {
+                                switch (c.application_cipher) {
+                                    inline else => |*p| {
+                                        const pv = &p.tls_1_3;
+                                        const P = @TypeOf(p.*);
+                                        const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
+                                        if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+                                            .counter = key_log.clientCounter(),
+                                            .client_random = &key_log.client_random,
+                                        }, .{
+                                            .CLIENT_TRAFFIC_SECRET = &client_secret,
+                                        });
+                                        pv.client_secret = client_secret;
+                                        pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+                                        pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+                                    },
+                                }
+                                c.write_seq = 0;
+                            },
+                            .update_not_requested => {},
+                            _ => return failRead(c, error.TlsIllegalParameter),
+                        }
+                    },
+                    else => return failRead(c, error.TlsUnexpectedMessage),
+                }
+                ct_i = next_handshake_i;
+                if (ct_i >= cleartext.len) break;
+            }
+            return 0;
+        },
+        .application_data => {
+            const n = try bw.write(limit.sliceConst(cleartext));
+            if (n < cleartext.len) {
+                const remainder = cleartext[n..];
+                input.unread(remainder);
+                c.remaining_cleartext_len = @intCast(remainder.len);
+            }
+            return n;
+        },
+        else => return failRead(c, error.TlsUnexpectedMessage),
+    }
 }

-fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
-    _ = context;
-    _ = limit;
-    @panic("TODO");
+fn readVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize {
+    var bw: std.io.BufferedWriter = undefined;
+    bw.initFixed(data[0]);
+    return read(context, &bw, .limited(data[0].len)) catch |err| switch (err) {
+        error.WriteFailed => unreachable,
+        else => |e| return e,
+    };
+}
+
+fn discard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize {
+    var null_writer: Writer.Null = undefined;
+    var bw = null_writer.writer().unbuffered();
+    return read(context, &bw, limit) catch |err| switch (err) {
+        error.WriteFailed => unreachable,
+        else => |e| return e,
+    };
 }

 fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
@@ -1427,12 +1260,8 @@ fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
     return error.ReadFailed;
 }

-fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
-    const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
-    defer if (locked) key_log_file.unlock();
-    key_log_file.seekFromEnd(0) catch {};
-    var w = key_log_file.writer().unbuffered();
-    inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
+fn logSecrets(bw: *std.io.BufferedWriter, context: anytype, secrets: anytype) void {
+    inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| bw.print("{s}" ++
         (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
         (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
             context.client_random,
@@ -1440,59 +1269,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
         }) catch {};
 }

-fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) std.io.Reader.Status {
-    const saved_buf = frag[in..];
-    if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
-        // There is cleartext at the beginning already which we need to preserve.
-        c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len);
-        @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf);
-    } else {
-        c.partial_cleartext_idx = 0;
-        c.partial_ciphertext_idx = 0;
-        c.partial_ciphertext_end = @intCast(saved_buf.len);
-        @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
-    }
-    return .{ .len = out, .end = c.eof() };
-}
-
-/// Note that `first` usually overlaps with `c.partially_read_buffer`.
-fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) std.io.Reader.Status {
-    if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
-        // There is cleartext at the beginning already which we need to preserve.
-        c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
-        // TODO: eliminate this call to copyForwards
-        std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
-        @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
-    } else {
-        c.partial_cleartext_idx = 0;
-        c.partial_ciphertext_idx = 0;
-        c.partial_ciphertext_end = @intCast(first.len + frag1.len);
-        // TODO: eliminate this call to copyForwards
-        std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
-        @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
-    }
-    return .{ .len = out, .end = c.eof() };
-}
-
-fn limitedOverlapCopy(frag: []u8, in: usize) void {
-    const first = frag[in..];
-    if (first.len <= in) {
-        // A single, non-overlapping memcpy suffices.
-        @memcpy(frag[0..first.len], first);
-    } else {
-        // One memcpy call would overlap, so just do this instead.
-        std.mem.copyForwards(u8, frag, first);
-    }
-}
-
-fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
-    if (index < s1.len) {
-        return s1[index];
-    } else {
-        return s2[index - s1.len];
-    }
-}
-
 inline fn big(x: anytype) @TypeOf(x) {
     return switch (native_endian) {
         .big => x,
@@ -1753,81 +1529,6 @@ const CertificatePublicKey = struct {
     }
 };

-/// Abstraction for sending multiple byte buffers to a slice of iovecs.
-const VecPut = struct {
-    iovecs: []const std.posix.iovec,
-    idx: usize = 0,
-    off: usize = 0,
-    total: usize = 0,
-
-    /// Returns the amount actually put which is always equal to bytes.len
-    /// unless the vectors ran out of space.
-    fn put(vp: *VecPut, bytes: []const u8) usize {
-        if (vp.idx >= vp.iovecs.len) return 0;
-        var bytes_i: usize = 0;
-        while (true) {
-            const v = vp.iovecs[vp.idx];
-            const dest = v.base[vp.off..v.len];
-            const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
-            @memcpy(dest[0..src.len], src);
-            bytes_i += src.len;
-            vp.off += src.len;
-            if (vp.off >= v.len) {
-                vp.off = 0;
-                vp.idx += 1;
-                if (vp.idx >= vp.iovecs.len) {
-                    vp.total += bytes_i;
-                    return bytes_i;
-                }
-            }
-            if (bytes_i >= bytes.len) {
-                vp.total += bytes_i;
-                return bytes_i;
-            }
-        }
-    }
-
-    /// Returns the next buffer that consecutive bytes can go into.
-    fn peek(vp: VecPut) []u8 {
-        if (vp.idx >= vp.iovecs.len) return &.{};
-        const v = vp.iovecs[vp.idx];
-        return v.base[vp.off..v.len];
-    }
-
-    // After writing to the result of peek(), one can call next() to
-    // advance the cursor.
-    fn next(vp: *VecPut, len: usize) void {
-        vp.total += len;
-        vp.off += len;
-        if (vp.off >= vp.iovecs[vp.idx].len) {
-            vp.off = 0;
-            vp.idx += 1;
-        }
-    }
-
-    fn freeSize(vp: VecPut) usize {
-        if (vp.idx >= vp.iovecs.len) return 0;
-        var total: usize = 0;
-        total += vp.iovecs[vp.idx].len - vp.off;
-        if (vp.idx + 1 >= vp.iovecs.len) return total;
-        for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len;
-        return total;
-    }
-};
-
-/// Limit iovecs to a specific byte size.
-fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec {
-    var bytes_left: usize = len;
-    for (iovecs, 0..) |*iovec, vec_i| {
-        if (bytes_left <= iovec.len) {
-            iovec.len = bytes_left;
-            return iovecs[0 .. vec_i + 1];
-        }
-        bytes_left -= iovec.len;
-    }
-    return iovecs;
-}
-
 /// The priority order here is chosen based on what crypto algorithms Zig has
 /// available in the standard library as well as what is faster. Following are
 /// a few data points on the relative performance of these algorithms.
diff --git a/lib/std/http.zig b/lib/std/http.zig
index fe73d9f477..fcb25af130 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -487,9 +487,7 @@ pub const Reader = struct {
                 return decompressor.compression.gzip.reader();
             },
             .zstd => {
-                decompressor.compression = .{ .zstd = .init(reader.in, .{
-                    .window_buffer = decompression_buffer,
-                }) };
+                decompressor.compression = .{ .zstd = .init(reader.in, .{ .verify_checksum = false }) };
                 return decompressor.compression.zstd.reader();
             },
             .compress => unreachable,
@@ -742,7 +740,7 @@ pub const Decompressor = struct {
     pub const Compression = union(enum) {
         deflate: std.compress.zlib.Decompressor,
         gzip: std.compress.gzip.Decompressor,
-        zstd: std.compress.zstd.Decompressor,
+        zstd: std.compress.zstd.Decompress,
         none: void,
     };

@@ -768,12 +766,8 @@ pub const Decompressor = struct {
                 return decompressor.compression.gzip.reader();
             },
             .zstd => {
-                const first_half = buffer[0 .. buffer.len / 2];
-                const second_half = buffer[buffer.len / 2 ..];
-                decompressor.buffered_reader = transfer_reader.buffered(first_half);
-                decompressor.compression = .{ .zstd = .init(&decompressor.buffered_reader, .{
-                    .window_buffer = second_half,
-                }) };
+                decompressor.buffered_reader = transfer_reader.buffered(buffer);
+                decompressor.compression = .{ .zstd = .init(&decompressor.buffered_reader, .{}) };
                 return decompressor.compression.gzip.reader();
             },
             .compress => unreachable,
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 4c014908e0..dffdcdd82e 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -28,7 +28,7 @@ tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.cr
 /// If non-null, ssl secrets are logged to a stream. Creating such a stream
 /// allows other processes with access to that stream to decrypt all
 /// traffic over connections created with this `Client`.
-ssl_key_log: ?*std.io.BufferedWriter = null,
+ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null,

 /// When this is `true`, the next time this client performs an HTTPS request,
 /// it will first rescan the system for root certificates.
@@ -230,8 +230,11 @@ pub const Connection = struct {
     stream_writer: net.Stream.Writer,
     stream_reader: net.Stream.Reader,
     /// HTTP protocol from client to server.
-    /// This either goes directly to `stream`, or to a TLS client.
+    /// This either goes directly to `stream_writer`, or to a TLS client.
     writer: std.io.BufferedWriter,
+    /// HTTP protocol from server to client.
+    /// This either comes directly from `stream_reader`, or from a TLS client.
+    reader: std.io.BufferedReader,
     /// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
     pool_node: std.DoublyLinkedList.Node,
     port: u16,
@@ -241,8 +244,6 @@ pub const Connection = struct {
     protocol: Protocol,

     const Plain = struct {
-        /// Data from `Connection.stream`.
-        reader: std.io.BufferedReader,
         connection: Connection,

         fn create(
@@ -267,6 +268,7 @@ pub const Connection = struct {
                     .stream_writer = stream.writer(),
                     .stream_reader = stream.reader(),
                     .writer = plain.connection.stream_writer.interface().buffered(socket_write_buffer),
+                    .reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
                     .pool_node = .{},
                     .port = port,
                     .host_len = @intCast(remote_host.len),
@@ -274,7 +276,6 @@ pub const Connection = struct {
                     .closing = false,
                     .protocol = .plain,
                 },
-                .reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
             };
             return plain;
         }
@@ -327,6 +328,7 @@ pub const Connection = struct {
                     .stream_writer = stream.writer(),
                     .stream_reader = stream.reader(),
                     .writer = tls.client.writer().buffered(socket_write_buffer),
+                    .reader = tls.client.reader().unbuffered(),
                     .pool_node = .{},
                     .port = port,
                     .host_len = @intCast(remote_host.len),
@@ -386,21 +388,6 @@ pub const Connection = struct {
         };
     }

-    /// This is either data from `stream`, or `Tls.client`.
-    fn reader(c: *Connection) *std.io.BufferedReader {
-        return switch (c.protocol) {
-            .tls => {
-                if (disable_tls) unreachable;
-                const tls: *Tls = @fieldParentPtr("connection", c);
-                return &tls.client.reader;
-            },
-            .plain => {
-                const plain: *Plain = @fieldParentPtr("connection", c);
-                return &plain.reader;
-            },
-        };
-    }
-
     /// If this is called without calling `flush` or `end`, data will be
     /// dropped unsent.
     pub fn destroy(c: *Connection) void {
@@ -1556,7 +1543,7 @@ pub fn request(
         .client = client,
         .connection = connection,
         .reader = .{
-            .in = connection.reader(),
+            .in = &connection.reader,
             .state = .ready,
             .body_state = undefined,
         },
@@ -1670,8 +1657,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {

     const decompress_buffer: []u8 = switch (response.head.content_encoding) {
         .identity => &.{},
-        .zstd => options.decompress_buffer orelse
-            try client.allocator.alloc(u8, std.compress.zstd.default_window_len * 2),
+        .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len),
         else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024),
     };
     defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
@@ -1681,7 +1667,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
     const list = storage.list;

     if (storage.allocator) |allocator| {
-        reader.readRemainingArrayList(allocator, null, list, storage.append_limit) catch |err| switch (err) {
+        reader.readRemainingArrayList(allocator, null, list, storage.append_limit, 128) catch |err| switch (err) {
             error.ReadFailed => return response.bodyErr().?,
             else => |e| return e,
         };
diff --git a/lib/std/io/BufferedReader.zig b/lib/std/io/BufferedReader.zig
index 23f45e93b3..fc5d531922 100644
--- a/lib/std/io/BufferedReader.zig
+++ b/lib/std/io/BufferedReader.zig
@@ -252,6 +252,12 @@ pub fn toss(br: *BufferedReader, n: usize) void {
     assert(br.seek <= br.end);
 }

+pub fn unread(noalias br: *BufferedReader, noalias data: []const u8) void {
+    _ = br;
+    _ = data;
+    @panic("TODO");
+}
+
 /// Equivalent to `peek` followed by `toss`.
 ///
 /// The data returned is invalidated by the next call to `take`, `peek`,
@@ -736,24 +742,25 @@ pub fn discardDelimiterExclusive(br: *BufferedReader, delimiter: u8) Reader.Shor
 /// Asserts buffer capacity is at least `n`.
 pub fn fill(br: *BufferedReader, n: usize) Reader.Error!void {
     assert(n <= br.buffer.len);
-    const buffer = br.buffer[0..br.end];
-    const seek = br.seek;
-    if (seek + n <= buffer.len) {
+    if (br.seek + n <= br.end) {
         @branchHint(.likely);
         return;
     }
-    if (seek > 0) {
-        const remainder = buffer[seek..];
-        std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
-        br.end = remainder.len;
-        br.seek = 0;
-    }
-    while (true) {
+    rebaseCapacity(br, n);
+    while (br.end < br.seek + n) {
         br.end += try br.unbuffered_reader.readVec(&.{br.buffer[br.end..]});
-        if (n <= br.end) return;
     }
 }

+/// Fills the buffer with at least one more byte of data, without advancing the
+/// seek position, doing exactly one underlying read.
+///
+/// Asserts buffer capacity is at least 1.
+pub fn fillMore(br: *BufferedReader) Reader.Error!void {
+    rebaseCapacity(br, 1);
+    br.end += try br.unbuffered_reader.readVec(&.{br.buffer[br.end..]});
+}
+
 /// Returns the next byte from the stream or returns `error.EndOfStream`.
 ///
 /// Does not advance the seek position.
@@ -783,7 +790,7 @@ pub fn takeByteSigned(br: *BufferedReader) Reader.Error!i8 {
     return @bitCast(try br.takeByte());
 }

-/// Asserts the buffer was initialized with a capacity at least `@sizeOf(T)`.
+/// Asserts the buffer was initialized with a capacity at least `@bitSizeOf(T) / 8`.
 pub inline fn takeInt(br: *BufferedReader, comptime T: type, endian: std.builtin.Endian) Reader.Error!T {
     const n = @divExact(@typeInfo(T).int.bits, 8);
     return std.mem.readInt(T, try br.takeArray(n), endian);
@@ -957,6 +964,14 @@ pub fn rebase(br: *BufferedReader) void {
     br.end = data.len;
 }

+/// Ensures `capacity` more data can be buffered without rebasing, by rebasing
+/// if necessary.
+///
+/// Asserts `capacity` is within the buffer capacity.
+pub fn rebaseCapacity(br: *BufferedReader, capacity: usize) void {
+    if (br.end > br.buffer.len - capacity) rebase(br);
+}
+
 /// Advances the stream and decreases the size of the storage buffer by `n`,
 /// returning the range of bytes no longer accessible by `br`.
 ///
diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig
index 41eb96d0d8..343de016ea 100644
--- a/lib/std/io/Reader.zig
+++ b/lib/std/io/Reader.zig
@@ -100,6 +100,10 @@ pub const Limit = enum(usize) {
         return s[0..l.minInt(s.len)];
     }

+    pub fn sliceConst(l: Limit, s: []const u8) []const u8 {
+        return s[0..l.minInt(s.len)];
+    }
+
     pub fn toInt(l: Limit) ?usize {
         return switch (l) {
             else => @intFromEnum(l),
diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig
index 10128e14c5..f00ba27daf 100644
--- a/lib/std/io/Writer.zig
+++ b/lib/std/io/Writer.zig
@@ -140,7 +140,7 @@ pub fn failingWriteFile(
     limit: std.io.Writer.Limit,
     headers_and_trailers: []const []const u8,
     headers_len: usize,
-) Error!usize {
+) FileError!usize {
     _ = context;
     _ = file;
     _ = offset;
@@ -165,7 +165,7 @@ pub fn unimplementedWriteFile(
     limit: std.io.Writer.Limit,
     headers_and_trailers: []const []const u8,
     headers_len: usize,
-) Error!usize {
+) FileError!usize {
     _ = context;
     _ = file;
     _ = offset;