From e1c37f70d4ae9a7bfa6de92dcb26e7cfdffc17c2 Mon Sep 17 00:00:00 2001 From: Nameless Date: Tue, 3 Oct 2023 14:26:06 -0500 Subject: [PATCH] std.http.Client: store *Connection instead of a pool node, buffer writes --- lib/std/crypto/tls/Client.zig | 2 +- lib/std/http/Client.zig | 198 ++++++++++++++++++---------------- lib/std/http/protocol.zig | 6 +- test/standalone/http.zig | 2 +- 4 files changed, 111 insertions(+), 97 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 37306dd37f..7671d06469 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -881,7 +881,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { /// The `iovecs` parameter is mutable because this function needs to mutate the fields in /// order to handle partial reads from the underlying stream layer. pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { - return readvAtLeast(c, stream, iovecs); + return readvAtLeast(c, stream, iovecs, 1); } /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d97334fe4c..55ae62a183 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -54,7 +54,7 @@ pub const ConnectionPool = struct { /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. - pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node { + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -65,7 +65,7 @@ pub const ConnectionPool = struct { if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); - return node; + return &node.data; } return null; @@ -89,10 +89,12 @@ pub const ConnectionPool = struct { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. - pub fn release(pool: *ConnectionPool, allocator: Allocator, node: *Node) void { + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); + const node = @fieldParentPtr(Node, "data", connection); + pool.used.remove(node); if (node.data.closing or pool.free_size == 0) { @@ -151,6 +153,8 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + pub const Protocol = enum { plain, tls }; stream: net.Stream, @@ -164,14 +168,16 @@ pub const Connection = struct { proxied: bool = false, closing: bool = false, - read_start: u16 = 0, - read_end: u16 = 0, + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, - pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), + .plain => conn.stream.readv(buffers), + .tls => conn.tls_client.readv(conn.stream, buffers), } catch |err| { // TODO: https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; @@ -188,58 +194,52 @@ pub const Connection = struct { pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; - const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + var iovecs = [1]std.os.iovec{ + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); if (nread == 0) return error.EndOfStream; conn.read_start = 0; - conn.read_end = @as(u16, @intCast(nread)); + conn.read_end = @intCast(nread); } pub fn peek(conn: *Connection) []const u8 { return conn.read_buf[conn.read_start..conn.read_end]; } - pub fn drop(conn: *Connection, num: u16) void { + pub fn drop(conn: *Connection, num: BufferSize) void { conn.read_start += num; } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } - - try conn.fill(); + return available_read; } - return out_index; - } + var iovecs = [2]std.os.iovec{ + .{ .iov_base = buffer.ptr, .iov_len = buffer.len }, + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end = @intCast(nread - buffer.len); + return buffer.len; + } + + return nread; } pub const ReadError = error{ @@ -257,7 +257,7 @@ pub const Connection = struct { return Reader{ .context = conn }; } - pub fn writeAll(conn: *Connection, buffer: []const u8) !void { + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { return switch (conn.protocol) { .plain => conn.stream.writeAll(buffer), .tls => conn.tls_client.writeAll(conn.stream, buffer), @@ -267,14 +267,27 @@ pub const Connection = struct { }; } - pub fn write(conn: *Connection, buffer: []const u8) !usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - .tls => conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_end + buffer.len > conn.write_buf.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; } pub const WriteError = error{ @@ -455,7 +468,7 @@ pub const Request = struct { uri: Uri, client: *Client, /// is null when this connection is released - connection: ?*ConnectionPool.Node, + connection: ?*Connection, method: http.Method, version: http.Version = .@"HTTP/1.1", @@ -489,7 +502,7 @@ pub const Request = struct { if (req.connection) |connection| { if (!req.response.parser.done) { // If the response wasn't fully read, then we need to close the connection. - connection.data.closing = true; + connection.closing = true; } req.client.connection_pool.release(req.client.allocator, connection); } @@ -548,8 +561,7 @@ pub const Request = struct { pub fn start(req: *Request, options: StartOptions) StartError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; - var buffered = std.io.bufferedWriter(req.connection.?.data.writer()); - const w = buffered.writer(); + const w = req.connection.?.writer(); try req.method.write(w); try w.writeByte(' '); @@ -558,9 +570,9 @@ pub const Request = struct { try req.uri.writeToStream(.{ .authority = true }, w); } else { try req.uri.writeToStream(.{ - .scheme = req.connection.?.data.proxied, - .authentication = req.connection.?.data.proxied, - .authority = req.connection.?.data.proxied, + .scheme = req.connection.?.proxied, + .authentication = req.connection.?.proxied, + .authority = req.connection.?.proxied, .path = true, .query = true, .raw = options.raw_uri, @@ -629,8 +641,8 @@ pub const Request = struct { try w.writeAll("\r\n"); } - if (req.connection.?.data.proxied) { - const proxy_headers: ?http.Headers = switch (req.connection.?.data.protocol) { + if (req.connection.?.proxied) { + const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) { .plain => if (req.client.http_proxy) |proxy| proxy.headers else null, .tls => if (req.client.https_proxy) |proxy| proxy.headers else null, }; @@ -649,7 +661,7 @@ pub const Request = struct { try w.writeAll("\r\n"); - try buffered.flush(); + try req.connection.?.flush(); } const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; @@ -665,7 +677,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip); + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); if (amt == 0 and req.response.parser.done) break; index += amt; } @@ -683,10 +695,10 @@ pub const Request = struct { pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); if (req.response.parser.state.isContent()) break; } @@ -701,7 +713,7 @@ pub const Request = struct { // we're switching protocols, so this connection is no longer doing http if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; req.response.parser.done = true; } @@ -712,9 +724,9 @@ pub const Request = struct { const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); if (res_keepalive and (req_keepalive or req_connection == null)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; } else { - req.connection.?.data.closing = true; + req.connection.?.closing = true; } if (req.response.transfer_encoding) |te| { @@ -827,10 +839,10 @@ pub const Request = struct { const has_trail = !req.response.parser.state.isContent(); while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); } if (has_trail) { @@ -868,16 +880,16 @@ pub const Request = struct { pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { - try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.data.writeAll(bytes); - try req.connection.?.data.writeAll("\r\n"); + try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.writer().writeAll(bytes); + try req.connection.?.writer().writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.?.data.write(bytes); + const amt = try req.connection.?.write(bytes); len.* -= amt; return amt; }, @@ -897,10 +909,12 @@ pub const Request = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"), + .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try req.connection.?.flush(); } }; @@ -1024,7 +1038,7 @@ pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, Network /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. /// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*ConnectionPool.Node { +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { if (client.connection_pool.findConnection(.{ .host = host, .port = port, @@ -1074,12 +1088,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; -pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node { +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { if (!net.has_unix_sockets) return error.Unsupported; if (client.connection_pool.findConnection(.{ @@ -1108,7 +1122,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } pub fn connectTunnel( @@ -1116,7 +1130,7 @@ pub fn connectTunnel( proxy: *ProxyInformation, tunnel_host: []const u8, tunnel_port: u16, -) !*ConnectionPool.Node { +) !*Connection { if (!proxy.supports_connect) return error.TunnelNotSupported; if (client.connection_pool.findConnection(.{ @@ -1130,7 +1144,7 @@ pub fn connectTunnel( _ = tunnel: { const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); errdefer { - conn.data.closing = true; + conn.closing = true; client.connection_pool.release(client.allocator, conn); } @@ -1171,12 +1185,12 @@ pub fn connectTunnel( // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. req.connection = null; - client.allocator.free(conn.data.host); - conn.data.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.data.host); + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); - conn.data.port = tunnel_port; - conn.data.closing = false; + conn.port = tunnel_port; + conn.closing = false; return conn; } catch { @@ -1190,7 +1204,7 @@ pub fn connectTunnel( const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { // pointer required so that `supports_connect` can be updated if a CONNECT fails const potential_proxy: ?*ProxyInformation = switch (protocol) { .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, @@ -1213,11 +1227,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio // fall back to using the proxy as a normal http proxy const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); errdefer { - conn.data.closing = true; + conn.closing = true; client.connection_pool.release(conn); } - conn.data.proxied = true; + conn.proxied = true; return conn; } @@ -1240,7 +1254,7 @@ pub const RequestOptions = struct { header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 }, /// Must be an already acquired connection. - connection: ?*ConnectionPool.Node = null, + connection: ?*Connection = null, pub const StorageStrategy = union(enum) { /// In this case, the client's Allocator will be used to store the diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index a369c38581..74e0207f34 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -529,7 +529,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; if (r.next_chunk_length == 0) r.done = true; @@ -553,7 +553,7 @@ pub const HeadersParser = struct { try conn.fill(); const i = r.findChunkedLen(conn.peek()); - conn.drop(@as(u16, @intCast(i))); + conn.drop(@intCast(i)); switch (r.state) { .invalid => return error.HttpChunkInvalid, @@ -582,7 +582,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; } else if (out_avail > 0) { const can_read: usize = @intCast(@min(data_avail, out_avail)); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 8b538a092f..71f3481767 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -680,7 +680,7 @@ pub fn main() !void { for (0..total_connections) |i| { var req = try client.request(.GET, uri, .{ .allocator = calloc }, .{}); req.response.parser.done = true; - req.connection.?.data.closing = false; + req.connection.?.closing = false; requests[i] = req; }