diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 7f62b2d597..d4d8f85ad1 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,11 +21,27 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, +connection_pool: std.TailQueue(Connection) = .{}, + +const ConnectionPool = std.TailQueue(Connection); +const ConnectionNode = ConnectionPool.Node; + +pub fn release(client: *Client, node: *ConnectionNode) void { + if (node.data.unusable) return node.data.close(client); + + client.connection_pool.append(node); +} + pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: std.crypto.tls.Client, + tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. protocol: Protocol, + host: []u8, + port: u16, + + // This connection has been part of a non keepalive request and cannot be added to the pool. + unusable: bool = false, pub const Protocol = enum { plain, tls }; @@ -56,6 +72,17 @@ pub const Connection = struct { .tls => return conn.tls_client.write(conn.stream, buffer), } } + + pub fn close(conn: *Connection, client: *const Client) void { + if (conn.protocol == .tls) { + // try to cleanly close the TLS connection, for any server that cares. + _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + } + + conn.stream.close(); + + client.allocator.free(conn.host); + } }; /// TODO: emit error.UnexpectedEndOfStream or something like that when the read @@ -63,7 +90,7 @@ pub const Connection = struct { /// close_notify protection on underlying TLS streams. pub const Request = struct { client: *Client, - connection: Connection, + connection: *ConnectionNode, redirects_left: u32, response: Response, /// These are stored in Request so that they are available when following @@ -79,6 +106,7 @@ pub const Request = struct { header_bytes: std.ArrayListUnmanaged(u8), max_header_bytes: usize, next_chunk_length: u64, + done: bool, pub const Headers = struct { status: http.Status, @@ -86,6 +114,7 @@ pub const Request = struct { location: ?[]const u8 = null, content_length: ?u64 = null, transfer_encoding: ?http.TransferEncoding = null, + connection_close: bool = true, pub fn parse(bytes: []const u8) !Response.Headers { var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n"); @@ -126,6 +155,14 @@ pub const Request = struct { if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { + headers.connection_close = false; + } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { + headers.connection_close = true; + } else { + return error.HttpConnectionHeaderUnsupported; + } } } @@ -185,10 +222,10 @@ pub const Request = struct { chunk_r, chunk_data, - pub fn zeroMeansEnd(state: State) bool { - return switch (state) { - .finished, .chunk_data => true, - else => false, + pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, }; } }; @@ -201,6 +238,7 @@ pub const Request = struct { .max_header_bytes = max, .header_bytes_owned = true, .next_chunk_length = undefined, + .done = false, }; } @@ -212,6 +250,7 @@ pub const Request = struct { .max_header_bytes = buf.len, .header_bytes_owned = false, .next_chunk_length = undefined, + .done = false, }; } @@ -501,6 +540,7 @@ pub const Request = struct { pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, + connection_close: bool = false, }; pub const Options = struct { @@ -545,6 +585,7 @@ pub const Request = struct { HttpHeadersExceededSizeLimit, HttpRedirectMissingLocation, HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, HttpContentLengthUnknown, TooManyHttpRedirects, ShortHttpStatusLine, @@ -669,8 +710,9 @@ pub const Request = struct { assert(len <= buffer.len); var index: usize = 0; while (index < len) { - const zero_means_end = req.response.state.zeroMeansEnd(); const amt = try readAdvanced(req, buffer[index..]); + const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; + if (amt == 0 and zero_means_end) break; index += amt; } @@ -680,7 +722,29 @@ pub const Request = struct { /// This one can return 0 without meaning EOF. /// TODO change to readvAdvanced pub fn readAdvanced(req: *Request, buffer: []u8) !usize { - var in = buffer[0..try req.connection.read(buffer)]; + if (req.response.done) { + if (req.response.headers.status.class() == .redirect) { + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = try std.Uri.parse(location); + const new_req = try req.client.request(new_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + } else { + return 0; + } + } + + var in = buffer[0..try req.connection.data.read(buffer)]; var out_index: usize = 0; while (true) { switch (req.response.state) { @@ -698,24 +762,10 @@ pub const Request = struct { if (req.response.state == .finished) { req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - if (req.response.headers.status.class() == .redirect) { - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = try std.Uri.parse(location); - const new_req = try req.client.request(new_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - assert(out_index == 0); - in = buffer[0..try req.connection.read(buffer)]; - continue; + if (req.response.headers.connection_close == true) { + req.connection.data.unusable = true; + } else { + req.connection.data.unusable = false; } if (req.response.headers.transfer_encoding) |transfer_encoding| { @@ -742,11 +792,29 @@ pub const Request = struct { return 0; }, .finished => { + const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); + req.response.next_chunk_length -= sub_amt; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + + req.response.done = true; + assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary. + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; + + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + return out_index + sub_amt; + } + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; + if (in.ptr == buffer.ptr) { - return in.len; + return sub_amt; } else { - mem.copy(u8, buffer[out_index..], in); - return out_index + in.len; + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + return out_index + sub_amt; } }, .chunk_size_prefix_r => switch (in.len) { @@ -793,7 +861,10 @@ pub const Request = struct { .invalid => return error.HttpHeadersInvalid, .chunk_data => { if (req.response.next_chunk_length == 0) { - req.response.state = .start; + req.response.done = true; + req.client.release(req.connection); + req.connection = undefined; + return out_index; } in = in[i..]; @@ -807,20 +878,27 @@ pub const Request = struct { // TODO https://github.com/ziglang/zig/issues/14039 const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); req.response.next_chunk_length -= sub_amt; - if (req.response.next_chunk_length > 0) { - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - return out_index; - } + + if (req.response.next_chunk_length == 0) { + req.response.state = .chunk_size_prefix_r; + in = in[sub_amt..]; + + if (req.response.headers.status.class() == .redirect) continue; + + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + out_index += sub_amt; + continue; + } + + if (req.response.headers.status.class() == .redirect) return 0; + + if (in.ptr == buffer.ptr) { + return sub_amt; + } else { + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + out_index += sub_amt; + return out_index; } - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - req.response.state = .chunk_size_prefix_r; - in = in[sub_amt..]; - continue; }, } } @@ -844,24 +922,52 @@ pub const Request = struct { }; pub fn deinit(client: *Client) void { + var next = client.connection_pool.first; + while (next) |node| { + next = node.next; + + node.data.close(client); + + client.allocator.destroy(node); + } + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection { - var conn: Connection = .{ +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode { + var potential = client.connection_pool.last; + while (potential) |node| { + const same_host = mem.eql(u8, node.data.host, host); + const same_port = node.data.port == port; + const same_protocol = node.data.protocol == protocol; + + if (same_host and same_port and same_protocol) { + client.connection_pool.remove(node); + return node; + } + + potential = node.prev; + } + + const conn = try client.allocator.create(ConnectionNode); + errdefer client.allocator.destroy(conn); + + conn.* = .{ .data = .{ .stream = try net.tcpConnectToHost(client.allocator, host, port), .tls_client = undefined, .protocol = protocol, - }; + .host = try client.allocator.dupe(u8, host), + .port = port, + } }; switch (protocol) { .plain => {}, .tls => { - conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host); + conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + conn.data.tls_client.allow_truncation_attacks = true; }, } @@ -908,10 +1014,15 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req try h.appendSlice(@tagName(headers.version)); try h.appendSlice("\r\nHost: "); try h.appendSlice(host); - try h.appendSlice("\r\nConnection: close\r\n\r\n"); + if (headers.connection_close) { + try h.appendSlice("\r\nConnection: close"); + } else { + try h.appendSlice("\r\nConnection: keep-alive"); + } + try h.appendSlice("\r\n\r\n"); const header_bytes = h.slice(); - try req.connection.writeAll(header_bytes); + try req.connection.data.writeAll(header_bytes); } return req;