std.http: add connection pooling and make keep-alive requests by default
parent 95f6a5935a
commit afb26f4e6b
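The diff below keeps finished keep-alive connections in a client-owned pool (connection_pool, a std.TailQueue(Connection)): connect first scans the pool for a node with the same host, port, and protocol and reuses it, and a fully read response that did not say "Connection: close" is handed back via release. The following sketch is not part of the commit; it only illustrates how a caller would benefit, assuming the request()/deinit() signatures visible in this diff, an assumed Request.read() helper for draining the body, default-initialized Request.Options, and a placeholder URL:

const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    var client: std.http.Client = .{ .allocator = gpa.allocator() };
    // deinit() now walks connection_pool, closing and freeing every pooled node.
    defer client.deinit();

    const uri = try std.Uri.parse("https://example.com/"); // placeholder URL

    // First request: connect() finds no pooled node for this host/port/protocol
    // and dials a fresh connection. The request head now carries
    // "Connection: keep-alive" by default (Request.Headers.connection_close
    // defaults to false).
    var req1 = try client.request(uri, .{}, .{});
    defer req1.deinit();

    var buf: [4096]u8 = undefined;
    while (true) {
        const n = try req1.read(&buf); // assumed read helper; drain the body
        if (n == 0) break;
    }

    // Once the body is fully read, the connection node is released back into
    // client.connection_pool (unless the server answered "Connection: close").
    // A second request to the same host/port/protocol reuses that node instead
    // of reconnecting.
    var req2 = try client.request(uri, .{}, .{});
    defer req2.deinit();
    while (true) {
        const n = try req2.read(&buf);
        if (n == 0) break;
    }
}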
@@ -21,11 +21,27 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,

connection_pool: std.TailQueue(Connection) = .{},

const ConnectionPool = std.TailQueue(Connection);
const ConnectionNode = ConnectionPool.Node;

pub fn release(client: *Client, node: *ConnectionNode) void {
    if (node.data.unusable) return node.data.close(client);

    client.connection_pool.append(node);
}

pub const Connection = struct {
    stream: net.Stream,
    /// undefined unless protocol is tls.
    tls_client: std.crypto.tls.Client,
    tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
    protocol: Protocol,
    host: []u8,
    port: u16,

    // This connection has been part of a non keepalive request and cannot be added to the pool.
    unusable: bool = false,

    pub const Protocol = enum { plain, tls };

@@ -56,6 +72,17 @@ pub const Connection = struct {
            .tls => return conn.tls_client.write(conn.stream, buffer),
        }
    }

    pub fn close(conn: *Connection, client: *const Client) void {
        if (conn.protocol == .tls) {
            // try to cleanly close the TLS connection, for any server that cares.
            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
        }

        conn.stream.close();

        client.allocator.free(conn.host);
    }
};

/// TODO: emit error.UnexpectedEndOfStream or something like that when the read
@@ -63,7 +90,7 @@ pub const Connection = struct {
/// close_notify protection on underlying TLS streams.
pub const Request = struct {
    client: *Client,
    connection: Connection,
    connection: *ConnectionNode,
    redirects_left: u32,
    response: Response,
    /// These are stored in Request so that they are available when following
@@ -79,6 +106,7 @@ pub const Request = struct {
        header_bytes: std.ArrayListUnmanaged(u8),
        max_header_bytes: usize,
        next_chunk_length: u64,
        done: bool,

        pub const Headers = struct {
            status: http.Status,
@@ -86,6 +114,7 @@ pub const Request = struct {
            location: ?[]const u8 = null,
            content_length: ?u64 = null,
            transfer_encoding: ?http.TransferEncoding = null,
            connection_close: bool = true,

            pub fn parse(bytes: []const u8) !Response.Headers {
                var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
@@ -126,6 +155,14 @@ pub const Request = struct {
            if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
            headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse
                return error.HttpTransferEncodingUnsupported;
        } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
            if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
                headers.connection_close = false;
            } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
                headers.connection_close = true;
            } else {
                return error.HttpConnectionHeaderUnsupported;
            }
        }
    }

@@ -185,10 +222,10 @@ pub const Request = struct {
            chunk_r,
            chunk_data,

            pub fn zeroMeansEnd(state: State) bool {
                return switch (state) {
                    .finished, .chunk_data => true,
                    else => false,
            pub fn isContent(self: State) bool {
                return switch (self) {
                    .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false,
                    .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true,
                };
            }
        };
@@ -201,6 +238,7 @@ pub const Request = struct {
                .max_header_bytes = max,
                .header_bytes_owned = true,
                .next_chunk_length = undefined,
                .done = false,
            };
        }

@@ -212,6 +250,7 @@ pub const Request = struct {
                .max_header_bytes = buf.len,
                .header_bytes_owned = false,
                .next_chunk_length = undefined,
                .done = false,
            };
        }

@@ -501,6 +540,7 @@ pub const Request = struct {
    pub const Headers = struct {
        version: http.Version = .@"HTTP/1.1",
        method: http.Method = .GET,
        connection_close: bool = false,
    };

    pub const Options = struct {
@@ -545,6 +585,7 @@ pub const Request = struct {
        HttpHeadersExceededSizeLimit,
        HttpRedirectMissingLocation,
        HttpTransferEncodingUnsupported,
        HttpConnectionHeaderUnsupported,
        HttpContentLengthUnknown,
        TooManyHttpRedirects,
        ShortHttpStatusLine,
@@ -669,8 +710,9 @@ pub const Request = struct {
        assert(len <= buffer.len);
        var index: usize = 0;
        while (index < len) {
            const zero_means_end = req.response.state.zeroMeansEnd();
            const amt = try readAdvanced(req, buffer[index..]);
            const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;

            if (amt == 0 and zero_means_end) break;
            index += amt;
        }
@@ -680,7 +722,29 @@ pub const Request = struct {
    /// This one can return 0 without meaning EOF.
    /// TODO change to readvAdvanced
    pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
        var in = buffer[0..try req.connection.read(buffer)];
        if (req.response.done) {
            if (req.response.headers.status.class() == .redirect) {
                if (req.redirects_left == 0) return error.TooManyHttpRedirects;

                const location = req.response.headers.location orelse
                    return error.HttpRedirectMissingLocation;
                const new_url = try std.Uri.parse(location);
                const new_req = try req.client.request(new_url, req.headers, .{
                    .max_redirects = req.redirects_left - 1,
                    .header_strategy = if (req.response.header_bytes_owned) .{
                        .dynamic = req.response.max_header_bytes,
                    } else .{
                        .static = req.response.header_bytes.unusedCapacitySlice(),
                    },
                });
                req.deinit();
                req.* = new_req;
            } else {
                return 0;
            }
        }

        var in = buffer[0..try req.connection.data.read(buffer)];
        var out_index: usize = 0;
        while (true) {
            switch (req.response.state) {
@@ -698,24 +762,10 @@ pub const Request = struct {
                    if (req.response.state == .finished) {
                        req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);

                        if (req.response.headers.status.class() == .redirect) {
                            if (req.redirects_left == 0) return error.TooManyHttpRedirects;
                            const location = req.response.headers.location orelse
                                return error.HttpRedirectMissingLocation;
                            const new_url = try std.Uri.parse(location);
                            const new_req = try req.client.request(new_url, req.headers, .{
                                .max_redirects = req.redirects_left - 1,
                                .header_strategy = if (req.response.header_bytes_owned) .{
                                    .dynamic = req.response.max_header_bytes,
                                } else .{
                                    .static = req.response.header_bytes.unusedCapacitySlice(),
                                },
                            });
                            req.deinit();
                            req.* = new_req;
                            assert(out_index == 0);
                            in = buffer[0..try req.connection.read(buffer)];
                            continue;
                        if (req.response.headers.connection_close == true) {
                            req.connection.data.unusable = true;
                        } else {
                            req.connection.data.unusable = false;
                        }

                        if (req.response.headers.transfer_encoding) |transfer_encoding| {
@@ -742,11 +792,29 @@ pub const Request = struct {
                    return 0;
                },
                .finished => {
                    const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
                    req.response.next_chunk_length -= sub_amt;

                    if (req.response.next_chunk_length == 0) {
                        req.client.release(req.connection);
                        req.connection = undefined;

                        req.response.done = true;
                        assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary.

                        if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;

                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                        return out_index + sub_amt;
                    }

                    if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;

                    if (in.ptr == buffer.ptr) {
                        return in.len;
                        return sub_amt;
                    } else {
                        mem.copy(u8, buffer[out_index..], in);
                        return out_index + in.len;
                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                        return out_index + sub_amt;
                    }
                },
                .chunk_size_prefix_r => switch (in.len) {
@@ -793,7 +861,10 @@ pub const Request = struct {
                .invalid => return error.HttpHeadersInvalid,
                .chunk_data => {
                    if (req.response.next_chunk_length == 0) {
                        req.response.state = .start;
                        req.response.done = true;
                        req.client.release(req.connection);
                        req.connection = undefined;

                        return out_index;
                    }
                    in = in[i..];
@@ -807,20 +878,27 @@ pub const Request = struct {
                    // TODO https://github.com/ziglang/zig/issues/14039
                    const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
                    req.response.next_chunk_length -= sub_amt;
                    if (req.response.next_chunk_length > 0) {
                        if (in.ptr == buffer.ptr) {
                            return sub_amt;
                        } else {
                            mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                            out_index += sub_amt;
                            return out_index;
                        }

                    if (req.response.next_chunk_length == 0) {
                        req.response.state = .chunk_size_prefix_r;
                        in = in[sub_amt..];

                        if (req.response.headers.status.class() == .redirect) continue;

                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                        out_index += sub_amt;
                        continue;
                    }

                    if (req.response.headers.status.class() == .redirect) return 0;

                    if (in.ptr == buffer.ptr) {
                        return sub_amt;
                    } else {
                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                        out_index += sub_amt;
                        return out_index;
                    }
                    mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
                    out_index += sub_amt;
                    req.response.state = .chunk_size_prefix_r;
                    in = in[sub_amt..];
                    continue;
                },
            }
        }
@@ -844,24 +922,52 @@ pub const Request = struct {
};

pub fn deinit(client: *Client) void {
    var next = client.connection_pool.first;
    while (next) |node| {
        next = node.next;

        node.data.close(client);

        client.allocator.destroy(node);
    }

    client.ca_bundle.deinit(client.allocator);
    client.* = undefined;
}

pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection {
    var conn: Connection = .{
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode {
    var potential = client.connection_pool.last;
    while (potential) |node| {
        const same_host = mem.eql(u8, node.data.host, host);
        const same_port = node.data.port == port;
        const same_protocol = node.data.protocol == protocol;

        if (same_host and same_port and same_protocol) {
            client.connection_pool.remove(node);
            return node;
        }

        potential = node.prev;
    }

    const conn = try client.allocator.create(ConnectionNode);
    errdefer client.allocator.destroy(conn);

    conn.* = .{ .data = .{
        .stream = try net.tcpConnectToHost(client.allocator, host, port),
        .tls_client = undefined,
        .protocol = protocol,
    };
        .host = try client.allocator.dupe(u8, host),
        .port = port,
    } };

    switch (protocol) {
        .plain => {},
        .tls => {
            conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host);
            conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
            // This is appropriate for HTTPS because the HTTP headers contain
            // the content length which is used to detect truncation attacks.
            conn.tls_client.allow_truncation_attacks = true;
            conn.data.tls_client.allow_truncation_attacks = true;
        },
    }

@@ -908,10 +1014,15 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
        try h.appendSlice(@tagName(headers.version));
        try h.appendSlice("\r\nHost: ");
        try h.appendSlice(host);
        try h.appendSlice("\r\nConnection: close\r\n\r\n");
        if (headers.connection_close) {
            try h.appendSlice("\r\nConnection: close");
        } else {
            try h.appendSlice("\r\nConnection: keep-alive");
        }
        try h.appendSlice("\r\n\r\n");

        const header_bytes = h.slice();
        try req.connection.writeAll(header_bytes);
        try req.connection.data.writeAll(header_bytes);
    }

    return req;
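Callers that do not want a request's connection to be pooled can still opt out per request through the new Request.Headers field. A minimal sketch under the same assumptions as the one near the top of this page (request()/deinit() signatures from this diff, an assumed Request.read() helper, default-initialized Request.Options):

const std = @import("std");

fn fetchWithoutKeepAlive(client: *std.http.Client, uri: std.Uri) !void {
    // The request head carries "Connection: close"; servers normally answer in
    // kind, so the connection is marked unusable and never enters the pool.
    var req = try client.request(uri, .{ .connection_close = true }, .{});
    defer req.deinit();

    var buf: [4096]u8 = undefined;
    while (true) {
        const n = try req.read(&buf); // assumed read helper, as above
        if (n == 0) break;
    }
}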