diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d44b1d098d..baf0239388 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -16,6 +16,9 @@ const testing = std.testing; pub const Request = @import("Client/Request.zig"); pub const Response = @import("Client/Response.zig"); +pub const default_connection_pool_size = 32; +const connection_pool_size = std.options.http_connection_pool_size; + /// Used for tcpConnectToHost and storing HTTP headers when an externally /// managed buffer is not provided. allocator: Allocator, @@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, -connection_mutex: std.Thread.Mutex = .{}, connection_pool: ConnectionPool = .{}, -connection_used: ConnectionPool = .{}, -pub const ConnectionPool = std.TailQueue(Connection); -pub const ConnectionNode = ConnectionPool.Node; +pub const ConnectionPool = struct { + pub const Criteria = struct { + host: []const u8, + port: u16, + is_tls: bool, + }; -/// Acquires an existing connection from the connection pool. This function is threadsafe. -/// If the caller already holds the connection mutex, it should pass `true` for `held`. -pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void { - if (!held) client.connection_mutex.lock(); - defer if (!held) client.connection_mutex.unlock(); + const Queue = std.TailQueue(Connection); + pub const Node = Queue.Node; - client.connection_pool.remove(node); - client.connection_used.append(node); -} + mutex: std.Thread.Mutex = .{}, + used: Queue = .{}, + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = default_connection_pool_size, -/// Tries to release a connection back to the connection pool. This function is threadsafe. -/// If the connection is marked as closing, it will be closed instead. -pub fn release(client: *Client, node: *ConnectionNode) void { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node { + pool.mutex.lock(); + defer pool.mutex.unlock(); - client.connection_used.remove(node); + var next = pool.free.last; + while (next) |node| : (next = node.prev) { + if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if (node.data.port != criteria.port) continue; + if (std.mem.eql(u8, node.data.host, criteria.host)) continue; - if (node.data.closing) { - node.data.close(client); + pool.acquireUnsafe(node); + return node; + } - return client.allocator.destroy(node); + return null; } - client.connection_pool.append(node); -} + /// Acquires an existing connection from the connection pool. This function is not threadsafe. + pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { + pool.free.remove(node); + pool.free_len -= 1; + + pool.used.append(node); + } + + /// Acquires an existing connection from the connection pool. This function is threadsafe. + pub fn acquire(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + return pool.acquireUnsafe(node); + } + + /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// If the connection is marked as closing, it will be closed instead. + pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.remove(node); + + if (node.data.closing) { + node.data.close(client); + + return client.allocator.destroy(node); + } + + if (pool.free_len + 1 >= pool.free_size) { + const popped = pool.free.popFirst() orelse unreachable; + + popped.data.close(client); + + return client.allocator.destroy(popped); + } + + pool.free.append(node); + pool.free_len += 1; + } + + /// Adds a newly created node to the pool of used connections. This function is threadsafe. + pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.append(node); + } + + pub fn deinit(pool: *ConnectionPool, client: *Client) void { + pool.mutex.lock(); + + var next = pool.free.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); + } + + next = pool.used.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); + } + + pool.* = undefined; + } +}; pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); @@ -142,25 +221,7 @@ pub const Connection = struct { }; pub fn deinit(client: *Client) void { - client.connection_mutex.lock(); - - var next = client.connection_pool.first; - while (next) |node| { - next = node.next; - - node.data.close(client); - - client.allocator.destroy(node); - } - - next = client.connection_used.first; - while (next) |node| { - next = node.next; - - node.data.close(client); - - client.allocator.destroy(node); - } + client.connection_pool.deinit(client); client.ca_bundle.deinit(client.allocator); client.* = undefined; @@ -168,36 +229,25 @@ pub fn deinit(client: *Client) void { pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { - { // Search through the connection pool for a potential connection. - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { + if (client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .is_tls = protocol == .tls, + })) |node| + return node; - var potential = client.connection_pool.last; - while (potential) |node| { - const same_host = mem.eql(u8, node.data.host, host); - const same_port = node.data.port == port; - const same_protocol = node.data.protocol == protocol; - - if (same_host and same_port and same_protocol) { - client.acquire(node, true); - return node; - } - - potential = node.prev; - } - } - - const conn = try client.allocator.create(ConnectionNode); + const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; - conn.* = .{ .data = .{ + conn.data = .{ .stream = try net.tcpConnectToHost(client.allocator, host, port), .tls_client = undefined, .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, - } }; + }; switch (protocol) { .plain => {}, @@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio }, } - { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); - - client.connection_used.append(conn); - } + client.connection_pool.addUsed(conn); return conn; } @@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; if (client.next_https_rescan_certs and protocol == .tls) { - client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. - defer client.connection_mutex.unlock(); + client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. + defer client.connection_pool.mutex.unlock(); if (client.next_https_rescan_certs) { try client.ca_bundle.rescan(client.allocator); diff --git a/lib/std/http/Client/Request.zig b/lib/std/http/Client/Request.zig index 26ce5cb7bf..9e2ebd2d6c 100644 --- a/lib/std/http/Client/Request.zig +++ b/lib/std/http/Client/Request.zig @@ -6,7 +6,7 @@ const assert = std.debug.assert; const Client = @import("../Client.zig"); const Connection = Client.Connection; -const ConnectionNode = Client.ConnectionNode; +const ConnectionNode = Client.ConnectionPool.Node; const Response = @import("Response.zig"); const Request = @This(); @@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void { if (!req.response.done) { // If the response wasn't fully read, then we need to close the connection. req.connection.data.closing = true; - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); } req.arena.deinit(); @@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { if (req.response.state == .finished) { req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - if (req.response.upgrade) |_| { + if (req.response.headers.upgrade) |_| { req.connection.data.closing = false; req.response.done = true; return i; @@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.response.next_chunk_length -= can_read; if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; } @@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.read_buffer_start += @intCast(ReadBufferIndex, can_read); if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; } @@ -293,7 +293,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { .chunk_data => { if (req.response.next_chunk_length == 0) { req.response.done = true; - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; return out_index; @@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.response.next_chunk_length -= can_read; if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; continue; @@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { } } -pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ - BadHeader, - InvalidCompression, - StreamTooLong, - InvalidWindowSize, - CompressionNotSupported -}; +pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported }; pub const Reader = std.io.Reader(*Request, ReadError, read); diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig index bc064f9a20..8b2a9a4918 100644 --- a/lib/std/http/Client/Response.zig +++ b/lib/std/http/Client/Response.zig @@ -32,6 +32,7 @@ pub const Headers = struct { transfer_encoding: ?http.TransferEncoding = null, transfer_compression: ?http.ContentEncoding = null, connection: http.Connection = .close, + upgrade: ?[]const u8 = null, number_of_headers: usize = 0, @@ -93,7 +94,7 @@ pub const Headers = struct { if (iter.next()) |second| { if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - + const trimmed = std.mem.trim(u8, second, " "); if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { @@ -122,6 +123,8 @@ pub const Headers = struct { } else { return error.HttpConnectionHeaderUnsupported; } + } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { + headers.upgrade = header_value; } } diff --git a/lib/std/std.zig b/lib/std/std.zig index c1c682e224..e888ade659 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -185,6 +185,11 @@ pub const options = struct { options_override.keep_sigpipe else false; + + pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size")) + options_override.http_connection_pool_size + else + http.Client.default_connection_pool_size; }; // This forces the start.zig file to be imported, and the comptime logic inside that