std.http.Client: store *Connection instead of a pool node, buffer writes

Nameless 2023-10-03 14:26:06 -05:00
parent 1afeada2d9
commit e1c37f70d4
4 changed files with 111 additions and 97 deletions
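In short, the connection pool's public API now hands out `*Connection` directly instead of the intrusive `*ConnectionPool.Node` that wraps it, so call sites lose the `.data` hop, and `Connection` gains a write buffer that only reaches the socket on `flush()`. A rough before/after sketch of a call site, mirroring the hunks below:

    // Before: requests held a pool node and reached through its payload.
    req.connection.?.data.closing = true;
    // After: requests hold the *Connection itself.
    req.connection.?.closing = true;

    // Writes are now staged in Connection.write_buf and only hit the socket when
    // Connection.flush() runs; start() and finish() call it, as the hunks below show.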

View File

@ -881,7 +881,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial reads from the underlying stream layer.
pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
return readvAtLeast(c, stream, iovecs);
return readvAtLeast(c, stream, iovecs, 1);
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
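The doc comment above is the reason `readv` takes a mutable `[]std.os.iovec`: after a short read, the vectors that still have unread space must be advanced before retrying. A minimal sketch of that adjustment with a hypothetical helper (not part of this change):

    const std = @import("std");

    /// Advance `iovecs` past `consumed` bytes, returning the still-unfilled tail.
    fn advanceIovecs(iovecs: []std.os.iovec, consumed: usize) []std.os.iovec {
        var n = consumed;
        for (iovecs, 0..) |*vec, i| {
            if (n < vec.iov_len) {
                vec.iov_base += n; // partially filled: bump the base, shrink the length
                vec.iov_len -= n;
                return iovecs[i..];
            }
            n -= vec.iov_len; // fully filled: skip this vector entirely
        }
        return iovecs[iovecs.len..]; // everything was filled
    }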

View File

@ -54,7 +54,7 @@ pub const ConnectionPool = struct {
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
/// If no connection is found, null is returned.
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
pool.mutex.lock();
defer pool.mutex.unlock();
@ -65,7 +65,7 @@ pub const ConnectionPool = struct {
if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue;
pool.acquireUnsafe(node);
return node;
return &node.data;
}
return null;
@ -89,10 +89,12 @@ pub const ConnectionPool = struct {
/// Tries to release a connection back to the connection pool. This function is threadsafe.
/// If the connection is marked as closing, it will be closed instead.
pub fn release(pool: *ConnectionPool, allocator: Allocator, node: *Node) void {
pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
pool.mutex.lock();
defer pool.mutex.unlock();
const node = @fieldParentPtr(Node, "data", connection);
pool.used.remove(node);
if (node.data.closing or pool.free_size == 0) {
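`release` now accepts the `*Connection` that callers actually hold and recovers the owning list node with `@fieldParentPtr`, which works because the pool embeds each `Connection` as the `data` field of its node. A stripped-down, self-contained sketch of the pattern (simplified node type, placeholder fields):

    const std = @import("std");

    const Connection = struct { closing: bool = false }; // placeholder for the real struct

    // Simplified stand-in for the pool's intrusive list node: the Connection is
    // embedded as the `data` field, just as in the diff above.
    const Node = struct {
        prev: ?*Node = null,
        next: ?*Node = null,
        data: Connection = .{},
    };

    test "recover the owning node from the embedded connection" {
        var node: Node = .{};
        const conn: *Connection = &node.data;
        // This is the trick release() now uses:
        const recovered = @fieldParentPtr(Node, "data", conn);
        try std.testing.expectEqual(&node, recovered);
    }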
@ -151,6 +153,8 @@ pub const ConnectionPool = struct {
/// An interface to either a plain or TLS connection.
pub const Connection = struct {
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
const BufferSize = std.math.IntFittingRange(0, buffer_size);
pub const Protocol = enum { plain, tls };
stream: net.Stream,
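`BufferSize` replaces the hard-coded `u16` read indices: `std.math.IntFittingRange(0, buffer_size)` resolves to the smallest unsigned integer type that can index every position in the buffers, so the index type tracks `buffer_size` automatically. A quick illustration with made-up bounds (not the actual TLS record length):

    const std = @import("std");

    comptime {
        // IntFittingRange picks the narrowest unsigned type covering the range.
        std.debug.assert(std.math.IntFittingRange(0, 255) == u8);
        std.debug.assert(std.math.IntFittingRange(0, 256) == u9);
        std.debug.assert(std.math.IntFittingRange(0, 20_000) == u15);
    }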
@ -164,14 +168,16 @@ pub const Connection = struct {
proxied: bool = false,
closing: bool = false,
read_start: u16 = 0,
read_end: u16 = 0,
read_start: BufferSize = 0,
read_end: BufferSize = 0,
write_end: BufferSize = 0,
read_buf: [buffer_size]u8 = undefined,
write_buf: [buffer_size]u8 = undefined,
pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.readAtLeast(buffer, len),
.tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
.plain => conn.stream.readv(buffers),
.tls => conn.tls_client.readv(conn.stream, buffers),
} catch |err| {
// TODO: https://github.com/ziglang/zig/issues/2473
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
@ -188,58 +194,52 @@ pub const Connection = struct {
pub fn fill(conn: *Connection) ReadError!void {
if (conn.read_end != conn.read_start) return;
const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
var iovecs = [1]std.os.iovec{
.{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
};
const nread = try conn.readvDirect(&iovecs);
if (nread == 0) return error.EndOfStream;
conn.read_start = 0;
conn.read_end = @as(u16, @intCast(nread));
conn.read_end = @intCast(nread);
}
pub fn peek(conn: *Connection) []const u8 {
return conn.read_buf[conn.read_start..conn.read_end];
}
pub fn drop(conn: *Connection, num: u16) void {
pub fn drop(conn: *Connection, num: BufferSize) void {
conn.read_start += num;
}
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
assert(len <= buffer.len);
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
const available_read = conn.read_end - conn.read_start;
const available_buffer = buffer.len;
var out_index: u16 = 0;
while (out_index < len) {
const available_read = conn.read_end - conn.read_start;
const available_buffer = buffer.len - out_index;
if (available_read > available_buffer) { // partially read buffered data
@memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
conn.read_start += @intCast(available_buffer);
if (available_read > available_buffer) { // partially read buffered data
@memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
out_index += @as(u16, @intCast(available_buffer));
conn.read_start += @as(u16, @intCast(available_buffer));
return available_buffer;
} else if (available_read > 0) { // fully read buffered data
@memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
conn.read_start += available_read;
break;
} else if (available_read > 0) { // fully read buffered data
@memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
out_index += available_read;
conn.read_start += available_read;
if (out_index >= len) break;
}
const leftover_buffer = available_buffer - available_read;
const leftover_len = len - out_index;
if (leftover_buffer > conn.read_buf.len) {
// skip the buffer if the output is large enough
return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
}
try conn.fill();
return available_read;
}
return out_index;
}
var iovecs = [2]std.os.iovec{
.{ .iov_base = buffer.ptr, .iov_len = buffer.len },
.{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
};
const nread = try conn.readvDirect(&iovecs);
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
return conn.readAtLeast(buffer, 1);
if (nread > buffer.len) {
conn.read_start = 0;
conn.read_end = @intCast(nread - buffer.len);
return buffer.len;
}
return nread;
}
pub const ReadError = error{
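The rewritten `read` drops the old `readAtLeast` loop. When the internal buffer is empty, it issues one scatter `readv` whose first vector is the caller's buffer and whose second is the connection's own `read_buf`, so surplus bytes are kept for later `peek`/`drop` instead of being lost or needing a second syscall. A small illustration of that bookkeeping with made-up sizes (no real I/O):

    const std = @import("std");

    test "scatter read keeps the overflow buffered" {
        var caller_buf: [100]u8 = undefined;
        var read_buf: [16384]u8 = undefined;
        // Same vector layout as the new Connection.read():
        const iovecs = [2]std.os.iovec{
            .{ .iov_base = &caller_buf, .iov_len = caller_buf.len },
            .{ .iov_base = &read_buf, .iov_len = read_buf.len },
        };
        // Pretend readvDirect() delivered 160 bytes across both vectors.
        const nread: usize = 160;
        const returned = @min(nread, iovecs[0].iov_len); // what read() reports to the caller
        const buffered = nread - returned; // what stays in read_buf for peek()/drop()
        try std.testing.expectEqual(@as(usize, 100), returned);
        try std.testing.expectEqual(@as(usize, 60), buffered);
    }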
@ -257,7 +257,7 @@ pub const Connection = struct {
return Reader{ .context = conn };
}
pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
return switch (conn.protocol) {
.plain => conn.stream.writeAll(buffer),
.tls => conn.tls_client.writeAll(conn.stream, buffer),
@ -267,14 +267,27 @@ pub const Connection = struct {
};
}
pub fn write(conn: *Connection, buffer: []const u8) !usize {
return switch (conn.protocol) {
.plain => conn.stream.write(buffer),
.tls => conn.tls_client.write(conn.stream, buffer),
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
if (conn.write_end + buffer.len > conn.write_buf.len) {
try conn.flush();
if (buffer.len > conn.write_buf.len) {
try conn.writeAllDirect(buffer);
return buffer.len;
}
}
@memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer);
conn.write_end += @intCast(buffer.len);
return buffer.len;
}
pub fn flush(conn: *Connection) WriteError!void {
if (conn.write_end == 0) return;
try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
conn.write_end = 0;
}
pub const WriteError = error{
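The new `write` follows the standard buffered-writer policy: append to `write_buf` while it fits, flush first when the incoming bytes would overflow it, and bypass the buffer entirely for payloads larger than the buffer itself, with `flush` draining whatever has been staged. A self-contained sketch of the same policy against an in-memory sink (hypothetical types, not the stdlib code):

    const std = @import("std");

    fn BufferedSink(comptime size: usize) type {
        return struct {
            buf: [size]u8 = undefined,
            end: usize = 0,
            out: std.ArrayList(u8), // stand-in for the underlying socket

            const Self = @This();

            fn write(self: *Self, bytes: []const u8) !usize {
                if (self.end + bytes.len > self.buf.len) {
                    try self.flush();
                    if (bytes.len > self.buf.len) {
                        // Too big to ever fit: send it directly, skipping the buffer.
                        try self.out.appendSlice(bytes);
                        return bytes.len;
                    }
                }
                @memcpy(self.buf[self.end..][0..bytes.len], bytes);
                self.end += bytes.len;
                return bytes.len;
            }

            fn flush(self: *Self) !void {
                if (self.end == 0) return;
                try self.out.appendSlice(self.buf[0..self.end]);
                self.end = 0;
            }
        };
    }

    test "small writes coalesce until flush" {
        var sink = BufferedSink(8){ .out = std.ArrayList(u8).init(std.testing.allocator) };
        defer sink.out.deinit();
        _ = try sink.write("GET ");
        _ = try sink.write("/ ");
        try std.testing.expectEqual(@as(usize, 0), sink.out.items.len); // nothing sent yet
        try sink.flush();
        try std.testing.expectEqualStrings("GET / ", sink.out.items);
    }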
@ -455,7 +468,7 @@ pub const Request = struct {
uri: Uri,
client: *Client,
/// is null when this connection is released
connection: ?*ConnectionPool.Node,
connection: ?*Connection,
method: http.Method,
version: http.Version = .@"HTTP/1.1",
@ -489,7 +502,7 @@ pub const Request = struct {
if (req.connection) |connection| {
if (!req.response.parser.done) {
// If the response wasn't fully read, then we need to close the connection.
connection.data.closing = true;
connection.closing = true;
}
req.client.connection_pool.release(req.client.allocator, connection);
}
@ -548,8 +561,7 @@ pub const Request = struct {
pub fn start(req: *Request, options: StartOptions) StartError!void {
if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding;
var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
const w = buffered.writer();
const w = req.connection.?.writer();
try req.method.write(w);
try w.writeByte(' ');
@ -558,9 +570,9 @@ pub const Request = struct {
try req.uri.writeToStream(.{ .authority = true }, w);
} else {
try req.uri.writeToStream(.{
.scheme = req.connection.?.data.proxied,
.authentication = req.connection.?.data.proxied,
.authority = req.connection.?.data.proxied,
.scheme = req.connection.?.proxied,
.authentication = req.connection.?.proxied,
.authority = req.connection.?.proxied,
.path = true,
.query = true,
.raw = options.raw_uri,
@ -629,8 +641,8 @@ pub const Request = struct {
try w.writeAll("\r\n");
}
if (req.connection.?.data.proxied) {
const proxy_headers: ?http.Headers = switch (req.connection.?.data.protocol) {
if (req.connection.?.proxied) {
const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) {
.plain => if (req.client.http_proxy) |proxy| proxy.headers else null,
.tls => if (req.client.https_proxy) |proxy| proxy.headers else null,
};
@ -649,7 +661,7 @@ pub const Request = struct {
try w.writeAll("\r\n");
try buffered.flush();
try req.connection.?.flush();
}
const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
@ -665,7 +677,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
if (amt == 0 and req.response.parser.done) break;
index += amt;
}
@ -683,10 +695,10 @@ pub const Request = struct {
pub fn wait(req: *Request) WaitError!void {
while (true) { // handle redirects
while (true) { // read headers
try req.connection.?.data.fill();
try req.connection.?.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
req.connection.?.data.drop(@as(u16, @intCast(nchecked)));
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek());
req.connection.?.drop(@intCast(nchecked));
if (req.response.parser.state.isContent()) break;
}
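Header parsing drives the `fill`/`peek`/`drop` trio directly: `fill` tops up the connection's read buffer, `peek` exposes the buffered bytes to the parser, and `drop` consumes only what `checkCompleteHead` actually examined, so any body bytes that arrived in the same read stay buffered for the body reader. A tiny sketch of the consume-only-what-was-checked idea over a plain slice (hypothetical, not the real HeadersParser):

    const std = @import("std");

    test "drop only the checked bytes; the body stays buffered" {
        var buffered: []const u8 = "HTTP/1.1 200 OK\r\n\r\nBODY";
        // Pretend the head parser reported it consumed everything up to the blank line:
        const nchecked = std.mem.indexOf(u8, buffered, "\r\n\r\n").? + 4;
        buffered = buffered[nchecked..]; // drop(nchecked)
        try std.testing.expectEqualStrings("BODY", buffered);
    }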
@ -701,7 +713,7 @@ pub const Request = struct {
// we're switching protocols, so this connection is no longer doing http
if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) {
req.connection.?.data.closing = false;
req.connection.?.closing = false;
req.response.parser.done = true;
}
@ -712,9 +724,9 @@ pub const Request = struct {
const res_connection = req.response.headers.getFirstValue("connection");
const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
if (res_keepalive and (req_keepalive or req_connection == null)) {
req.connection.?.data.closing = false;
req.connection.?.closing = false;
} else {
req.connection.?.data.closing = true;
req.connection.?.closing = true;
}
if (req.response.transfer_encoding) |te| {
@ -827,10 +839,10 @@ pub const Request = struct {
const has_trail = !req.response.parser.state.isContent();
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.?.data.fill();
try req.connection.?.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
req.connection.?.data.drop(@as(u16, @intCast(nchecked)));
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek());
req.connection.?.drop(@intCast(nchecked));
}
if (has_trail) {
@ -868,16 +880,16 @@ pub const Request = struct {
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
switch (req.transfer_encoding) {
.chunked => {
try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
try req.connection.?.data.writeAll(bytes);
try req.connection.?.data.writeAll("\r\n");
try req.connection.?.writer().print("{x}\r\n", .{bytes.len});
try req.connection.?.writer().writeAll(bytes);
try req.connection.?.writer().writeAll("\r\n");
return bytes.len;
},
.content_length => |*len| {
if (len.* < bytes.len) return error.MessageTooLong;
const amt = try req.connection.?.data.write(bytes);
const amt = try req.connection.?.write(bytes);
len.* -= amt;
return amt;
},
@ -897,10 +909,12 @@ pub const Request = struct {
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) FinishError!void {
switch (req.transfer_encoding) {
.chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
.chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"),
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
.none => {},
}
try req.connection.?.flush();
}
};
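Because writes are now buffered, `finish` ends with `flush`, so the chunked terminator (or the last staged body bytes) actually reaches the socket before the client starts waiting on the response. A small illustration of the chunked wire format that `write` and `finish` produce for a 5-byte body, assuming a single write call:

    const std = @import("std");

    test "chunked framing for one write plus finish" {
        var out = std.ArrayList(u8).init(std.testing.allocator);
        defer out.deinit();
        const body = "hello";
        // One chunked write(): "<hex length>\r\n" ++ data ++ "\r\n"
        try out.writer().print("{x}\r\n", .{body.len});
        try out.appendSlice(body);
        try out.appendSlice("\r\n");
        // finish(): the zero-length chunk that terminates the body.
        try out.appendSlice("0\r\n\r\n");
        try std.testing.expectEqualStrings("5\r\nhello\r\n0\r\n\r\n", out.items);
    }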
@ -1024,7 +1038,7 @@ pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, Network
/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
/// This function is threadsafe.
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*ConnectionPool.Node {
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
if (client.connection_pool.findConnection(.{
.host = host,
.port = port,
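With the signature change, `connectTcp` both returns and the pool reuses plain `*Connection` values; a repeated connect to the same host, port, and protocol is satisfied by `findConnection` instead of a new dial. A hedged usage sketch (error handling and TLS setup elided; names as in the diff):

    const std = @import("std");

    // Usage sketch only; a real program would handle errors and certificates.
    pub fn demo(allocator: std.mem.Allocator) !void {
        var client = std.http.Client{ .allocator = allocator };
        defer client.deinit();

        const conn = try client.connectTcp("example.com", 443, .tls);
        // ... use the connection ...
        client.connection_pool.release(client.allocator, conn);

        // As long as the released connection stayed open, this second call is
        // satisfied by findConnection() and returns the same *Connection
        // instead of dialing again.
        const again = try client.connectTcp("example.com", 443, .tls);
        client.connection_pool.release(client.allocator, again);
    }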
@ -1074,12 +1088,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
client.connection_pool.addUsed(conn);
return conn;
return &conn.data;
}
pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError;
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node {
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection {
if (!net.has_unix_sockets) return error.Unsupported;
if (client.connection_pool.findConnection(.{
@ -1108,7 +1122,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti
client.connection_pool.addUsed(conn);
return conn;
return &conn.data;
}
pub fn connectTunnel(
@ -1116,7 +1130,7 @@ pub fn connectTunnel(
proxy: *ProxyInformation,
tunnel_host: []const u8,
tunnel_port: u16,
) !*ConnectionPool.Node {
) !*Connection {
if (!proxy.supports_connect) return error.TunnelNotSupported;
if (client.connection_pool.findConnection(.{
@ -1130,7 +1144,7 @@ pub fn connectTunnel(
_ = tunnel: {
const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
errdefer {
conn.data.closing = true;
conn.closing = true;
client.connection_pool.release(client.allocator, conn);
}
@ -1171,12 +1185,12 @@ pub fn connectTunnel(
// this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized.
req.connection = null;
client.allocator.free(conn.data.host);
conn.data.host = try client.allocator.dupe(u8, tunnel_host);
errdefer client.allocator.free(conn.data.host);
client.allocator.free(conn.host);
conn.host = try client.allocator.dupe(u8, tunnel_host);
errdefer client.allocator.free(conn.host);
conn.data.port = tunnel_port;
conn.data.closing = false;
conn.port = tunnel_port;
conn.closing = false;
return conn;
} catch {
@ -1190,7 +1204,7 @@ pub fn connectTunnel(
const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused };
pub const ConnectError = ConnectErrorPartial || RequestError;
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection {
// pointer required so that `supports_connect` can be updated if a CONNECT fails
const potential_proxy: ?*ProxyInformation = switch (protocol) {
.plain => if (client.http_proxy) |*proxy_info| proxy_info else null,
@ -1213,11 +1227,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
// fall back to using the proxy as a normal http proxy
const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
errdefer {
conn.data.closing = true;
conn.closing = true;
client.connection_pool.release(conn);
}
conn.data.proxied = true;
conn.proxied = true;
return conn;
}
@ -1240,7 +1254,7 @@ pub const RequestOptions = struct {
header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 },
/// Must be an already acquired connection.
connection: ?*ConnectionPool.Node = null,
connection: ?*Connection = null,
pub const StorageStrategy = union(enum) {
/// In this case, the client's Allocator will be used to store the

View File

@ -529,7 +529,7 @@ pub const HeadersParser = struct {
try conn.fill();
const nread = @min(conn.peek().len, data_avail);
conn.drop(@as(u16, @intCast(nread)));
conn.drop(@intCast(nread));
r.next_chunk_length -= nread;
if (r.next_chunk_length == 0) r.done = true;
@ -553,7 +553,7 @@ pub const HeadersParser = struct {
try conn.fill();
const i = r.findChunkedLen(conn.peek());
conn.drop(@as(u16, @intCast(i)));
conn.drop(@intCast(i));
switch (r.state) {
.invalid => return error.HttpChunkInvalid,
@ -582,7 +582,7 @@ pub const HeadersParser = struct {
try conn.fill();
const nread = @min(conn.peek().len, data_avail);
conn.drop(@as(u16, @intCast(nread)));
conn.drop(@intCast(nread));
r.next_chunk_length -= nread;
} else if (out_avail > 0) {
const can_read: usize = @intCast(@min(data_avail, out_avail));

View File

@ -680,7 +680,7 @@ pub fn main() !void {
for (0..total_connections) |i| {
var req = try client.request(.GET, uri, .{ .allocator = calloc }, .{});
req.response.parser.done = true;
req.connection.?.data.closing = false;
req.connection.?.closing = false;
requests[i] = req;
}