diff --git a/lib/std/Io.zig b/lib/std/Io.zig index 19b34c6968..7cce6397bd 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -717,7 +717,12 @@ pub fn Poller(comptime StreamEnum: type) type { const unused = r.buffer[r.end..]; if (unused.len >= min_len) return unused; } - if (r.seek > 0) r.rebase(r.buffer.len) catch unreachable; + if (r.seek > 0) { + const data = r.buffer[r.seek..r.end]; + @memmove(r.buffer[0..data.len], data); + r.seek = 0; + r.end = data.len; + } { var list: std.ArrayListUnmanaged(u8) = .{ .items = r.buffer[0..r.end], diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index c4923127ee..8fa9291467 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -86,12 +86,12 @@ pub const VTable = struct { /// `Reader.buffer`, whichever is bigger. readVec: *const fn (r: *Reader, data: [][]u8) Error!usize = defaultReadVec, - /// Ensures `capacity` more data can be buffered without rebasing. + /// Ensures `capacity` data can be buffered without rebasing. /// /// Asserts `capacity` is within buffer capacity, or that the stream ends /// within `capacity` bytes. /// - /// Only called when `capacity` cannot fit into the unused capacity of + /// Only called when `capacity` cannot be satisfied by unused capacity of /// `buffer`. /// /// The default implementation moves buffered data to the start of @@ -1037,7 +1037,7 @@ fn fillUnbuffered(r: *Reader, n: usize) Error!void { /// /// Asserts buffer capacity is at least 1. pub fn fillMore(r: *Reader) Error!void { - try rebase(r, 1); + try rebase(r, r.end - r.seek + 1); var bufs: [1][]u8 = .{""}; _ = try r.vtable.readVec(r, &bufs); } @@ -1205,24 +1205,6 @@ pub fn takeLeb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result { } }))) orelse error.Overflow; } -pub fn expandTotalCapacity(r: *Reader, allocator: Allocator, n: usize) Allocator.Error!void { - if (n <= r.buffer.len) return; - if (r.seek > 0) rebase(r, r.buffer.len); - var list: ArrayList(u8) = .{ - .items = r.buffer[0..r.end], - .capacity = r.buffer.len, - }; - defer r.buffer = list.allocatedSlice(); - try list.ensureTotalCapacity(allocator, n); -} - -pub const FillAllocError = Error || Allocator.Error; - -pub fn fillAlloc(r: *Reader, allocator: Allocator, n: usize) FillAllocError!void { - try expandTotalCapacity(r, allocator, n); - return fill(r, n); -} - fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result { const result_info = @typeInfo(Result).int; comptime assert(result_info.bits % 7 == 0); @@ -1253,9 +1235,9 @@ fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Resu } } -/// Ensures `capacity` more data can be buffered without rebasing. +/// Ensures `capacity` data can be buffered without rebasing. pub fn rebase(r: *Reader, capacity: usize) RebaseError!void { - if (r.end + capacity <= r.buffer.len) { + if (r.buffer.len - r.seek >= capacity) { @branchHint(.likely); return; } @@ -1263,11 +1245,12 @@ pub fn rebase(r: *Reader, capacity: usize) RebaseError!void { } pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void { - if (r.end <= r.buffer.len - capacity) return; + assert(r.buffer.len - r.seek < capacity); const data = r.buffer[r.seek..r.end]; @memmove(r.buffer[0..data.len], data); r.seek = 0; r.end = data.len; + assert(r.buffer.len - r.seek >= capacity); } test fixed { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index aef9a60232..f9c897e3f0 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -183,7 +183,6 @@ const InitError = error{ /// `input` is asserted to have buffer capacity at least `min_buffer_len`. pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { assert(input.buffer.len >= min_buffer_len); - assert(output.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -1124,12 +1123,6 @@ fn readIndirect(c: *Client) Reader.Error!usize { if (record_end > input.buffered().len) return 0; } - if (r.seek == r.end) { - r.seek = 0; - r.end = 0; - } - const cleartext_buffer = r.buffer[r.end..]; - const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { .tls_1_3 => { @@ -1145,7 +1138,8 @@ fn readIndirect(c: *Client) Reader.Error!usize { const operand: V = pad ++ mem.toBytes(big(c.read_seq)); break :nonce @as(V, pv.server_iv) ^ operand; }; - const cleartext = cleartext_buffer[0..ciphertext.len]; + rebase(r, ciphertext.len); + const cleartext = r.buffer[r.end..][0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch return failRead(c, error.TlsBadRecordMac); // TODO use scalar, non-slice version @@ -1171,7 +1165,8 @@ fn readIndirect(c: *Client) Reader.Error!usize { }; const ciphertext = input.take(message_len) catch unreachable; // already peeked const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked - const cleartext = cleartext_buffer[0..ciphertext.len]; + rebase(r, ciphertext.len); + const cleartext = r.buffer[r.end..][0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch return failRead(c, error.TlsBadRecordMac); break :cleartext .{ cleartext.len, ct }; @@ -1179,7 +1174,7 @@ fn readIndirect(c: *Client) Reader.Error!usize { else => unreachable, }, }; - const cleartext = cleartext_buffer[0..cleartext_len]; + const cleartext = r.buffer[r.end..][0..cleartext_len]; c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow); switch (inner_ct) { .alert => { @@ -1275,6 +1270,15 @@ fn readIndirect(c: *Client) Reader.Error!usize { } } +fn rebase(r: *Reader, capacity: usize) void { + if (r.buffer.len - r.end >= capacity) return; + const data = r.buffer[r.seek..r.end]; + @memmove(r.buffer[0..data.len], data); + r.seek = 0; + r.end = data.len; + assert(r.buffer.len - r.end >= capacity); +} + fn failRead(c: *Client, err: ReadError) error{ReadFailed} { c.read_err = err; return error.ReadFailed; diff --git a/lib/std/http.zig b/lib/std/http.zig index 09251f6c69..640ac2e208 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -329,6 +329,7 @@ pub const Reader = struct { /// read from `in`. trailers: []const u8 = &.{}, body_err: ?BodyError = null, + max_head_len: usize, pub const RemainingChunkLen = enum(u64) { head = 0, @@ -387,10 +388,11 @@ pub const Reader = struct { pub fn receiveHead(reader: *Reader) HeadError![]const u8 { reader.trailers = &.{}; const in = reader.in; + const max_head_len = reader.max_head_len; var hp: HeadParser = .{}; var head_len: usize = 0; while (true) { - if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize; + if (head_len >= max_head_len) return error.HttpHeadersOversize; const remaining = in.buffered()[head_len..]; if (remaining.len == 0) { in.fillMore() catch |err| switch (err) { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index fe28a930a4..f052943816 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{}, /// /// If the entire HTTP header cannot fit in this amount of bytes, /// `error.HttpHeadersOversize` will be returned from `Request.wait`. -read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +read_buffer_size: usize = 8192, /// Each `Connection` allocates this amount for the writer buffer. write_buffer_size: usize = 1024, @@ -302,18 +302,22 @@ pub const Connection = struct { const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len); errdefer gpa.free(base); const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; - const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; + // The TLS client wants enough buffer for the max encrypted frame + // size, and the HTTP body reader wants enough buffer for the + // entire HTTP header. This means we need a combined upper bound. + const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size; + const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len]; const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; - const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; - const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size]; - assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len); + const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size]; + assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len); @memcpy(host_buffer, remote_host); const tls: *Tls = @ptrCast(base); tls.* = .{ .connection = .{ .client = client, .stream_writer = stream.writer(tls_write_buffer), - .stream_reader = stream.reader(tls_read_buffer), + .stream_reader = stream.reader(socket_read_buffer), .pool_node = .{}, .port = port, .host_len = @intCast(remote_host.len), @@ -329,8 +333,8 @@ pub const Connection = struct { .host = .{ .explicit = remote_host }, .ca = .{ .bundle = client.ca_bundle }, .ssl_key_log = client.ssl_key_log, - .read_buffer = read_buffer, - .write_buffer = write_buffer, + .read_buffer = tls_read_buffer, + .write_buffer = socket_write_buffer, // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. .allow_truncation_attacks = true, @@ -348,8 +352,9 @@ pub const Connection = struct { } fn allocLen(client: *Client, host_len: usize) usize { - return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + - client.write_buffer_size + client.read_buffer_size; + const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size; + return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size + + client.write_buffer_size + client.tls_buffer_size; } fn host(tls: *Tls) []u8 { @@ -1214,6 +1219,7 @@ pub const Request = struct { .state = .ready, // Populated when `http.Reader.bodyReader` is called. .interface = undefined, + .max_head_len = r.client.read_buffer_size, }; r.redirect_behavior.subtractOne(); } @@ -1679,6 +1685,7 @@ pub fn request( .state = .ready, // Populated when `http.Reader.bodyReader` is called. .interface = undefined, + .max_head_len = client.read_buffer_size, }, .keep_alive = options.keep_alive, .method = method, diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index ebea94acc9..c62906827a 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -29,6 +29,7 @@ pub fn init(in: *Reader, out: *Writer) Server { .state = .ready, // Populated when `http.Reader.bodyReader` is called. .interface = undefined, + .max_head_len = in.buffer.len, }, .out = out, }; @@ -251,6 +252,7 @@ pub const Request = struct { .in = undefined, .state = .received_head, .interface = undefined, + .max_head_len = 4096, }, .out = undefined, };