From 1afeada2d95e50efe651bd6227719ca4003dad96 Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 2 Oct 2023 19:57:43 -0500 Subject: [PATCH 01/12] std.http.Client: enhance proxy support adds connectTunnel to form a HTTP CONNECT tunnel to the desired host. Primarily implemented for proxies, but like connectUnix may be called by any user. adds loadDefaultProxies to load proxy information from common environment variables (http_proxy, HTTP_PROXY, https_proxy, HTTPS_PROXY, all_proxy, ALL_PROXY). - no_proxy and NO_PROXY are currently unsupported. splits proxy into http_proxy and https_proxy, adds headers field for arbitrary headers to each proxy. --- lib/std/Uri.zig | 112 +++++++++---- lib/std/http/Client.zig | 328 +++++++++++++++++++++++++++++++-------- test/standalone/http.zig | 34 ++-- 3 files changed, 358 insertions(+), 116 deletions(-) diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 6952839e71..0a98c5b641 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -208,24 +208,45 @@ pub fn parseWithoutScheme(text: []const u8) ParseError!Uri { return uri; } -pub fn format( +pub const WriteToStreamOptions = struct { + /// When true, include the scheme part of the URI. + scheme: bool = false, + + /// When true, include the user and password part of the URI. Ignored if `authority` is false. + authentication: bool = false, + + /// When true, include the authority part of the URI. + authority: bool = false, + + /// When true, include the path part of the URI. + path: bool = false, + + /// When true, include the query part of the URI. Ignored when `path` is false. + query: bool = false, + + /// When true, include the fragment part of the URI. Ignored when `path` is false. + fragment: bool = false, + + /// When true, do not escape any part of the URI. 
+ raw: bool = false, +}; + +pub fn writeToStream( uri: Uri, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, + options: WriteToStreamOptions, writer: anytype, ) @TypeOf(writer).Error!void { - _ = options; - - const needs_absolute = comptime std.mem.indexOf(u8, fmt, "+") != null; - const needs_path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0; - const raw_uri = comptime std.mem.indexOf(u8, fmt, "r") != null; - const needs_fragment = comptime std.mem.indexOf(u8, fmt, "#") != null; - - if (needs_absolute) { + if (options.scheme) { try writer.writeAll(uri.scheme); try writer.writeAll(":"); - if (uri.host) |host| { + + if (options.authority and uri.host != null) { try writer.writeAll("//"); + } + } + + if (options.authority) { + if (options.authentication and uri.host != null) { if (uri.user) |user| { try writer.writeAll(user); if (uri.password) |password| { @@ -234,7 +255,9 @@ pub fn format( } try writer.writeAll("@"); } + } + if (uri.host) |host| { try writer.writeAll(host); if (uri.port) |port| { @@ -244,39 +267,62 @@ pub fn format( } } - if (needs_path) { + if (options.path) { if (uri.path.len == 0) { try writer.writeAll("/"); + } else if (options.raw) { + try writer.writeAll(uri.path); } else { - if (raw_uri) { - try writer.writeAll(uri.path); - } else { - try Uri.writeEscapedPath(writer, uri.path); - } + try writeEscapedPath(writer, uri.path); } - if (uri.query) |q| { + if (options.query) if (uri.query) |q| { try writer.writeAll("?"); - if (raw_uri) { + if (options.raw) { try writer.writeAll(q); } else { - try Uri.writeEscapedQuery(writer, q); + try writeEscapedQuery(writer, q); } - } + }; - if (needs_fragment) { - if (uri.fragment) |f| { - try writer.writeAll("#"); - if (raw_uri) { - try writer.writeAll(f); - } else { - try Uri.writeEscapedQuery(writer, f); - } + if (options.fragment) if (uri.fragment) |f| { + try writer.writeAll("#"); + if (options.raw) { + try writer.writeAll(f); + } else { + try writeEscapedQuery(writer, f); } 
- } + }; } } +pub fn format( + uri: Uri, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, +) @TypeOf(writer).Error!void { + _ = options; + + const scheme = comptime std.mem.indexOf(u8, fmt, ":") != null or fmt.len == 0; + const authentication = comptime std.mem.indexOf(u8, fmt, "@") != null or fmt.len == 0; + const authority = comptime std.mem.indexOf(u8, fmt, "+") != null or fmt.len == 0; + const path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0; + const query = comptime std.mem.indexOf(u8, fmt, "?") != null or fmt.len == 0; + const fragment = comptime std.mem.indexOf(u8, fmt, "#") != null or fmt.len == 0; + const raw = comptime std.mem.indexOf(u8, fmt, "r") != null or fmt.len == 0; + + return writeToStream(uri, .{ + .scheme = scheme, + .authentication = authentication, + .authority = authority, + .path = path, + .query = query, + .fragment = fragment, + .raw = raw, + }, writer); +} + /// Parses the URI or returns an error. /// The return value will contain unescaped strings pointing into the /// original `text`. Each component that is provided, will be non-`null`. 
@@ -709,7 +755,7 @@ test "URI query escaping" { const parsed = try Uri.parse(address); // format the URI to escape it - const formatted_uri = try std.fmt.allocPrint(std.testing.allocator, "{}", .{parsed}); + const formatted_uri = try std.fmt.allocPrint(std.testing.allocator, "{/?}", .{parsed}); defer std.testing.allocator.free(formatted_uri); try std.testing.expectEqualStrings("/?response-content-type=application%2Foctet-stream", formatted_uri); } @@ -727,6 +773,6 @@ test "format" { }; var buf = std.ArrayList(u8).init(std.testing.allocator); defer buf.deinit(); - try uri.format("+/", .{}, buf.writer()); + try uri.format(":/?#", .{}, buf.writer()); try std.testing.expectEqualSlices(u8, "file:/foo/bar/baz", buf.items); } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 9e475df51b..d97334fe4c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -18,6 +18,7 @@ pub const connection_pool_size = std.options.http_connection_pool_size; allocator: Allocator, ca_bundle: std.crypto.Certificate.Bundle = .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, + /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, @@ -25,7 +26,11 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, -proxy: ?HttpProxy = null, +/// This is the proxy that will handle http:// connections. It *must not* be modified when the client has any active connections. +http_proxy: ?ProxyInformation = null, + +/// This is the proxy that will handle https:// connections. It *must not* be modified when the client has any active connections. +https_proxy: ?ProxyInformation = null, /// A set of linked lists of connections that can be reused. 
pub const ConnectionPool = struct { @@ -33,7 +38,7 @@ pub const ConnectionPool = struct { pub const Criteria = struct { host: []const u8, port: u16, - is_tls: bool, + protocol: Connection.Protocol, }; const Queue = std.DoublyLinkedList(Connection); @@ -55,9 +60,9 @@ pub const ConnectionPool = struct { var next = pool.free.last; while (next) |node| : (next = node.prev) { - if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if (node.data.protocol != criteria.protocol) continue; if (node.data.port != criteria.port) continue; - if (!mem.eql(u8, node.data.host, criteria.host)) continue; + if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); return node; @@ -84,23 +89,23 @@ pub const ConnectionPool = struct { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. - pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void { + pub fn release(pool: *ConnectionPool, allocator: Allocator, node: *Node) void { pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(node); - if (node.data.closing) { - node.data.deinit(client); - return client.allocator.destroy(node); + if (node.data.closing or pool.free_size == 0) { + node.data.close(allocator); + return allocator.destroy(node); } if (pool.free_len >= pool.free_size) { const popped = pool.free.popFirst() orelse unreachable; pool.free_len -= 1; - popped.data.deinit(client); - client.allocator.destroy(popped); + popped.data.close(allocator); + allocator.destroy(popped); } if (node.data.proxied) { @@ -128,7 +133,7 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.deinit(client); + node.data.close(client.allocator); } next = pool.used.first; @@ -136,7 +141,7 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.deinit(client); + 
node.data.close(client.allocator); } pool.* = undefined; @@ -283,19 +288,15 @@ pub const Connection = struct { return Writer{ .context = conn }; } - pub fn close(conn: *Connection, client: *const Client) void { + pub fn close(conn: *Connection, allocator: Allocator) void { if (conn.protocol == .tls) { // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - client.allocator.destroy(conn.tls_client); + allocator.destroy(conn.tls_client); } conn.stream.close(); - } - - pub fn deinit(conn: *Connection, client: *const Client) void { - conn.close(client); - client.allocator.free(conn.host); + allocator.free(conn.host); } }; @@ -490,7 +491,7 @@ pub const Request = struct { // If the response wasn't fully read, then we need to close the connection. connection.data.closing = true; } - req.client.connection_pool.release(req.client, connection); + req.client.connection_pool.release(req.client.allocator, connection); } req.arena.deinit(); @@ -509,7 +510,7 @@ pub const Request = struct { .zstd => |*zstd| zstd.deinit(), } - req.client.connection_pool.release(req.client, req.connection.?); + req.client.connection_pool.release(req.client.allocator, req.connection.?); req.connection = null; const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; @@ -554,24 +555,16 @@ pub const Request = struct { try w.writeByte(' '); if (req.method == .CONNECT) { - try w.writeAll(req.uri.host.?); - try w.writeByte(':'); - try w.print("{}", .{req.uri.port.?}); + try req.uri.writeToStream(.{ .authority = true }, w); } else { - if (req.connection.?.data.proxied) { - // proxied connections require the full uri - if (options.raw_uri) { - try w.print("{+/r}", .{req.uri}); - } else { - try w.print("{+/}", .{req.uri}); - } - } else { - if (options.raw_uri) { - try w.print("{/r}", .{req.uri}); - } else { - try w.print("{/}", .{req.uri}); - } - } + try req.uri.writeToStream(.{ + .scheme = 
req.connection.?.data.proxied, + .authentication = req.connection.?.data.proxied, + .authority = req.connection.?.data.proxied, + .path = true, + .query = true, + .raw = options.raw_uri, + }, w); } try w.writeByte(' '); try w.writeAll(@tagName(req.version)); @@ -579,7 +572,7 @@ pub const Request = struct { if (!req.headers.contains("host")) { try w.writeAll("Host: "); - try w.writeAll(req.uri.host.?); + try req.uri.writeToStream(.{ .authority = true }, w); try w.writeAll("\r\n"); } @@ -636,6 +629,24 @@ pub const Request = struct { try w.writeAll("\r\n"); } + if (req.connection.?.data.proxied) { + const proxy_headers: ?http.Headers = switch (req.connection.?.data.protocol) { + .plain => if (req.client.http_proxy) |proxy| proxy.headers else null, + .tls => if (req.client.https_proxy) |proxy| proxy.headers else null, + }; + + if (proxy_headers) |headers| { + for (headers.list.items) |entry| { + if (entry.value.len == 0) continue; + + try w.writeAll(entry.name); + try w.writeAll(": "); + try w.writeAll(entry.value); + try w.writeAll("\r\n"); + } + } + } + try w.writeAll("\r\n"); try buffered.flush(); @@ -893,18 +904,15 @@ pub const Request = struct { } }; -pub const HttpProxy = struct { - pub const ProxyAuthentication = union(enum) { - basic: []const u8, - custom: []const u8, - }; +pub const ProxyInformation = struct { + allocator: Allocator, + headers: http.Headers, protocol: Connection.Protocol, host: []const u8, - port: ?u16 = null, + port: u16, - /// The value for the Proxy-Authorization header. - auth: ?ProxyAuthentication = null, + supports_connect: bool = true, }; /// Release all associated resources with the client. 
@@ -912,19 +920,115 @@ pub const HttpProxy = struct { pub fn deinit(client: *Client) void { client.connection_pool.deinit(client); + if (client.http_proxy) |*proxy| { + proxy.allocator.free(proxy.host); + proxy.headers.deinit(); + } + + if (client.https_proxy) |*proxy| { + proxy.allocator.free(proxy.host); + proxy.headers.deinit(); + } + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub const ConnectUnproxiedError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +/// Uses the *_proxy environment variable to set any unset proxies for the client. +/// This function *must not* be called when the client has any active connections. +pub fn loadDefaultProxies(client: *Client) !void { + if (client.http_proxy == null) http: { + const content: []const u8 = if (std.process.hasEnvVarConstant("http_proxy")) + try std.process.getEnvVarOwned(client.allocator, "http_proxy") + else if (std.process.hasEnvVarConstant("HTTP_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "HTTP_PROXY") + else if (std.process.hasEnvVarConstant("all_proxy")) + try std.process.getEnvVarOwned(client.allocator, "all_proxy") + else if (std.process.hasEnvVarConstant("ALL_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") + else + break :http; + defer client.allocator.free(content); + + const uri = try Uri.parse(content); + + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; + client.http_proxy = .{ + .allocator = client.allocator, + .headers = .{ .allocator = client.allocator }, + + .protocol = protocol, + .host = if (uri.host) |host| try client.allocator.dupe(u8, host) else return error.UriMissingHost, + .port = uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }, + }; + + if (uri.user != null 
and uri.password != null) { + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? }); + defer client.allocator.free(unencoded); + + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len)); + defer client.allocator.free(buffer); + + const result = std.base64.standard.Encoder.encode(buffer, unencoded); + + try client.http_proxy.?.headers.append("proxy-authorization", result); + } + } + + if (client.https_proxy == null) https: { + const content: []const u8 = if (std.process.hasEnvVarConstant("https_proxy")) + try std.process.getEnvVarOwned(client.allocator, "https_proxy") + else if (std.process.hasEnvVarConstant("HTTPS_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "HTTPS_PROXY") + else if (std.process.hasEnvVarConstant("all_proxy")) + try std.process.getEnvVarOwned(client.allocator, "all_proxy") + else if (std.process.hasEnvVarConstant("ALL_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") + else + break :https; + defer client.allocator.free(content); + + const uri = try Uri.parse(content); + + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; + client.http_proxy = .{ + .allocator = client.allocator, + .headers = .{ .allocator = client.allocator }, + + .protocol = protocol, + .host = if (uri.host) |host| try client.allocator.dupe(u8, host) else return error.UriMissingHost, + .port = uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }, + }; + + if (uri.user != null and uri.password != null) { + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? 
}); + defer client.allocator.free(unencoded); + + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len)); + defer client.allocator.free(buffer); + + const result = std.base64.standard.Encoder.encode(buffer, unencoded); + + try client.https_proxy.?.headers.append("proxy-authorization", result); + } + } +} + +pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. /// This function is threadsafe. -pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node { +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*ConnectionPool.Node { if (client.connection_pool.findConnection(.{ .host = host, .port = port, - .is_tls = protocol == .tls, + .protocol = protocol, })) |node| return node; @@ -948,8 +1052,8 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: conn.data = .{ .stream = stream, .tls_client = undefined, - .protocol = protocol, + .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, }; @@ -981,7 +1085,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti if (client.connection_pool.findConnection(.{ .host = path, .port = 0, - .is_tls = false, + .protocol = .plain, })) |node| return node; @@ -1007,34 +1111,120 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return conn; } -// Prevents a dependency loop in request() -const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused }; -pub const 
ConnectError = ConnectErrorPartial || RequestError; +pub fn connectTunnel( + client: *Client, + proxy: *ProxyInformation, + tunnel_host: []const u8, + tunnel_port: u16, +) !*ConnectionPool.Node { + if (!proxy.supports_connect) return error.TunnelNotSupported; -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .is_tls = protocol == .tls, + .host = tunnel_host, + .port = tunnel_port, + .protocol = proxy.protocol, })) |node| return node; - if (client.proxy) |proxy| { - const proxy_port: u16 = proxy.port orelse switch (proxy.protocol) { - .plain => 80, - .tls => 443, + var maybe_valid = false; + _ = tunnel: { + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.data.closing = true; + client.connection_pool.release(client.allocator, conn); + } + + const uri = Uri{ + .scheme = "http", + .user = null, + .password = null, + .host = tunnel_host, + .port = tunnel_port, + .path = "", + .query = null, + .fragment = null, }; - const conn = try client.connectUnproxied(proxy.host, proxy_port, proxy.protocol); - conn.data.proxied = true; + // we can use a small buffer here because a CONNECT response should be very small + var buffer: [8096]u8 = undefined; + + var req = client.request(.CONNECT, uri, proxy.headers, .{ + .handle_redirects = false, + .connection = conn, + .header_strategy = .{ .static = buffer[0..] 
}, + }) catch |err| { + std.log.debug("err {}", .{err}); + break :tunnel err; + }; + defer req.deinit(); + + req.start(.{ .raw_uri = true }) catch |err| break :tunnel err; + req.wait() catch |err| break :tunnel err; + + if (req.response.status.class() == .server_error) { + maybe_valid = true; + break :tunnel error.ServerError; + } + + if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + + // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + req.connection = null; + + client.allocator.free(conn.data.host); + conn.data.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.data.host); + + conn.data.port = tunnel_port; + conn.data.closing = false; return conn; - } else { - return client.connectUnproxied(host, port, protocol); - } + } catch { + // something went wrong with the tunnel + proxy.supports_connect = maybe_valid; + return error.TunnelNotSupported; + }; } -pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{ +// Prevents a dependency loop in request() +const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; +pub const ConnectError = ConnectErrorPartial || RequestError; + +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { + // pointer required so that `supports_connect` can be updated if a CONNECT fails + const potential_proxy: ?*ProxyInformation = switch (protocol) { + .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, + .tls => if (client.https_proxy) |*proxy_info| proxy_info else null, + }; + + if (potential_proxy) |proxy| { + // don't attempt to proxy the proxy thru itself. 
+ if (std.mem.eql(u8, proxy.host, host) and proxy.port == port and proxy.protocol == protocol) { + return client.connectTcp(host, port, protocol); + } + + _ = if (proxy.supports_connect) tunnel: { + return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + error.TunnelNotSupported => break :tunnel, + else => |e| return e, + }; + }; + + // fall back to using the proxy as a normal http proxy + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.data.closing = true; + client.connection_pool.release(conn); + } + + conn.data.proxied = true; + return conn; + } + + return client.connectTcp(host, port, protocol); +} + +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 00fd4397b0..8b538a092f 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -226,8 +226,11 @@ pub fn main() !void { const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); var client = Client{ .allocator = calloc }; + errdefer client.deinit(); // defer client.deinit(); handled below + try client.loadDefaultProxies(); + { // read content-length response var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -251,7 +254,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // read large content-length response var h = http.Headers{ .allocator = calloc }; @@ -275,7 +278,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // send head request and not read chunked var h = http.Headers{ 
.allocator = calloc }; @@ -301,7 +304,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // read chunked response var h = http.Headers{ .allocator = calloc }; @@ -326,7 +329,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // send head request and not read chunked var h = http.Headers{ .allocator = calloc }; @@ -352,7 +355,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // check trailing headers var h = http.Headers{ .allocator = calloc }; @@ -377,7 +380,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // send content-length request var h = http.Headers{ .allocator = calloc }; @@ -409,7 +412,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // read content-length response with connection close var h = http.Headers{ .allocator = calloc }; @@ -468,7 +471,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // relative redirect var h = http.Headers{ .allocator = calloc }; @@ -492,7 +495,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 
1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // redirect from root var h = http.Headers{ .allocator = calloc }; @@ -516,7 +519,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // absolute redirect var h = http.Headers{ .allocator = calloc }; @@ -540,7 +543,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // too many redirects var h = http.Headers{ .allocator = calloc }; @@ -562,7 +565,7 @@ pub fn main() !void { } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // check client without segfault by connection error after redirection var h = http.Headers{ .allocator = calloc }; @@ -579,11 +582,14 @@ pub fn main() !void { try req.start(.{}); const result = req.wait(); - try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error + // a proxy without an upstream is likely to return a 5xx status. 
+ if (client.http_proxy == null) { + try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error + } } // connection has been kept alive - try testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); { // Client.fetch() var h = http.Headers{ .allocator = calloc }; From e1c37f70d4ae9a7bfa6de92dcb26e7cfdffc17c2 Mon Sep 17 00:00:00 2001 From: Nameless Date: Tue, 3 Oct 2023 14:26:06 -0500 Subject: [PATCH 02/12] std.http.Client: store *Connection instead of a pool node, buffer writes --- lib/std/crypto/tls/Client.zig | 2 +- lib/std/http/Client.zig | 198 ++++++++++++++++++---------------- lib/std/http/protocol.zig | 6 +- test/standalone/http.zig | 2 +- 4 files changed, 111 insertions(+), 97 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 37306dd37f..7671d06469 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -881,7 +881,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { /// The `iovecs` parameter is mutable because this function needs to mutate the fields in /// order to handle partial reads from the underlying stream layer. pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { - return readvAtLeast(c, stream, iovecs); + return readvAtLeast(c, stream, iovecs, 1); } /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d97334fe4c..55ae62a183 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -54,7 +54,7 @@ pub const ConnectionPool = struct { /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. 
- pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node { + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -65,7 +65,7 @@ pub const ConnectionPool = struct { if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); - return node; + return &node.data; } return null; @@ -89,10 +89,12 @@ pub const ConnectionPool = struct { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. - pub fn release(pool: *ConnectionPool, allocator: Allocator, node: *Node) void { + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); + const node = @fieldParentPtr(Node, "data", connection); + pool.used.remove(node); if (node.data.closing or pool.free_size == 0) { @@ -151,6 +153,8 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. 
pub const Connection = struct { pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + pub const Protocol = enum { plain, tls }; stream: net.Stream, @@ -164,14 +168,16 @@ pub const Connection = struct { proxied: bool = false, closing: bool = false, - read_start: u16 = 0, - read_end: u16 = 0, + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, - pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), + .plain => conn.stream.readv(buffers), + .tls => conn.tls_client.readv(conn.stream, buffers), } catch |err| { // TODO: https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; @@ -188,58 +194,52 @@ pub const Connection = struct { pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; - const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + var iovecs = [1]std.os.iovec{ + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); if (nread == 0) return error.EndOfStream; conn.read_start = 0; - conn.read_end = @as(u16, @intCast(nread)); + conn.read_end = @intCast(nread); } pub fn peek(conn: *Connection) []const u8 { return conn.read_buf[conn.read_start..conn.read_end]; } - pub fn drop(conn: *Connection, num: u16) void { + pub fn drop(conn: *Connection, num: BufferSize) void { conn.read_start += num; } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); + pub fn read(conn: *Connection, buffer: []u8) 
ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } - - try conn.fill(); + return available_read; } - return out_index; - } + var iovecs = [2]std.os.iovec{ + .{ .iov_base = buffer.ptr, .iov_len = buffer.len }, + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end 
= @intCast(nread - buffer.len); + return buffer.len; + } + + return nread; } pub const ReadError = error{ @@ -257,7 +257,7 @@ pub const Connection = struct { return Reader{ .context = conn }; } - pub fn writeAll(conn: *Connection, buffer: []const u8) !void { + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { return switch (conn.protocol) { .plain => conn.stream.writeAll(buffer), .tls => conn.tls_client.writeAll(conn.stream, buffer), @@ -267,14 +267,27 @@ pub const Connection = struct { }; } - pub fn write(conn: *Connection, buffer: []const u8) !usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - .tls => conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_end + buffer.len > conn.write_buf.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; } pub const WriteError = error{ @@ -455,7 +468,7 @@ pub const Request = struct { uri: Uri, client: *Client, /// is null when this connection is released - connection: ?*ConnectionPool.Node, + connection: ?*Connection, method: http.Method, version: http.Version = .@"HTTP/1.1", @@ -489,7 +502,7 @@ pub const Request = struct { if (req.connection) |connection| { if (!req.response.parser.done) { // If the response wasn't fully read, then we need to close the connection. 
- connection.data.closing = true; + connection.closing = true; } req.client.connection_pool.release(req.client.allocator, connection); } @@ -548,8 +561,7 @@ pub const Request = struct { pub fn start(req: *Request, options: StartOptions) StartError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; - var buffered = std.io.bufferedWriter(req.connection.?.data.writer()); - const w = buffered.writer(); + const w = req.connection.?.writer(); try req.method.write(w); try w.writeByte(' '); @@ -558,9 +570,9 @@ pub const Request = struct { try req.uri.writeToStream(.{ .authority = true }, w); } else { try req.uri.writeToStream(.{ - .scheme = req.connection.?.data.proxied, - .authentication = req.connection.?.data.proxied, - .authority = req.connection.?.data.proxied, + .scheme = req.connection.?.proxied, + .authentication = req.connection.?.proxied, + .authority = req.connection.?.proxied, .path = true, .query = true, .raw = options.raw_uri, @@ -629,8 +641,8 @@ pub const Request = struct { try w.writeAll("\r\n"); } - if (req.connection.?.data.proxied) { - const proxy_headers: ?http.Headers = switch (req.connection.?.data.protocol) { + if (req.connection.?.proxied) { + const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) { .plain => if (req.client.http_proxy) |proxy| proxy.headers else null, .tls => if (req.client.https_proxy) |proxy| proxy.headers else null, }; @@ -649,7 +661,7 @@ pub const Request = struct { try w.writeAll("\r\n"); - try buffered.flush(); + try req.connection.?.flush(); } const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; @@ -665,7 +677,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip); + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); if (amt == 0 and 
req.response.parser.done) break; index += amt; } @@ -683,10 +695,10 @@ pub const Request = struct { pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); if (req.response.parser.state.isContent()) break; } @@ -701,7 +713,7 @@ pub const Request = struct { // we're switching protocols, so this connection is no longer doing http if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; req.response.parser.done = true; } @@ -712,9 +724,9 @@ pub const Request = struct { const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); if (res_keepalive and (req_keepalive or req_connection == null)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; } else { - req.connection.?.data.closing = true; + req.connection.?.closing = true; } if (req.response.transfer_encoding) |te| { @@ -827,10 +839,10 @@ pub const Request = struct { const has_trail = !req.response.parser.state.isContent(); while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try 
req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); } if (has_trail) { @@ -868,16 +880,16 @@ pub const Request = struct { pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { - try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.data.writeAll(bytes); - try req.connection.?.data.writeAll("\r\n"); + try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.writer().writeAll(bytes); + try req.connection.?.writer().writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.?.data.write(bytes); + const amt = try req.connection.?.write(bytes); len.* -= amt; return amt; }, @@ -897,10 +909,12 @@ pub const Request = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"), + .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try req.connection.?.flush(); } }; @@ -1024,7 +1038,7 @@ pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, Network /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. /// This function is threadsafe. 
-pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*ConnectionPool.Node { +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { if (client.connection_pool.findConnection(.{ .host = host, .port = port, @@ -1074,12 +1088,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; -pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node { +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { if (!net.has_unix_sockets) return error.Unsupported; if (client.connection_pool.findConnection(.{ @@ -1108,7 +1122,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } pub fn connectTunnel( @@ -1116,7 +1130,7 @@ pub fn connectTunnel( proxy: *ProxyInformation, tunnel_host: []const u8, tunnel_port: u16, -) !*ConnectionPool.Node { +) !*Connection { if (!proxy.supports_connect) return error.TunnelNotSupported; if (client.connection_pool.findConnection(.{ @@ -1130,7 +1144,7 @@ pub fn connectTunnel( _ = tunnel: { const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); errdefer { - conn.data.closing = true; + conn.closing = true; client.connection_pool.release(client.allocator, conn); } @@ -1171,12 +1185,12 @@ pub fn connectTunnel( // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. 
req.connection = null; - client.allocator.free(conn.data.host); - conn.data.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.data.host); + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); - conn.data.port = tunnel_port; - conn.data.closing = false; + conn.port = tunnel_port; + conn.closing = false; return conn; } catch { @@ -1190,7 +1204,7 @@ pub fn connectTunnel( const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { // pointer required so that `supports_connect` can be updated if a CONNECT fails const potential_proxy: ?*ProxyInformation = switch (protocol) { .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, @@ -1213,11 +1227,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio // fall back to using the proxy as a normal http proxy const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); errdefer { - conn.data.closing = true; + conn.closing = true; client.connection_pool.release(conn); } - conn.data.proxied = true; + conn.proxied = true; return conn; } @@ -1240,7 +1254,7 @@ pub const RequestOptions = struct { header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 }, /// Must be an already acquired connection. 
- connection: ?*ConnectionPool.Node = null, + connection: ?*Connection = null, pub const StorageStrategy = union(enum) { /// In this case, the client's Allocator will be used to store the diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index a369c38581..74e0207f34 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -529,7 +529,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; if (r.next_chunk_length == 0) r.done = true; @@ -553,7 +553,7 @@ pub const HeadersParser = struct { try conn.fill(); const i = r.findChunkedLen(conn.peek()); - conn.drop(@as(u16, @intCast(i))); + conn.drop(@intCast(i)); switch (r.state) { .invalid => return error.HttpChunkInvalid, @@ -582,7 +582,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; } else if (out_avail > 0) { const can_read: usize = @intCast(@min(data_avail, out_avail)); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 8b538a092f..71f3481767 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -680,7 +680,7 @@ pub fn main() !void { for (0..total_connections) |i| { var req = try client.request(.GET, uri, .{ .allocator = calloc }, .{}); req.response.parser.done = true; - req.connection.?.data.closing = false; + req.connection.?.closing = false; requests[i] = req; } From 0eef21d8ec290564ab503e5ad25f4c0c86f04d45 Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 5 Oct 2023 12:19:06 -0500 Subject: [PATCH 03/12] std.http.Client: add option to disable https std_options.http_connection_pool_size removed in favor of ``` client.connection_pool.resize(client.allocator, size); ``` std_options.http_disable_tls will remove all https capability from std.http 
when true. Any https request will error with `error.TlsInitializationFailed`. Solves #17051. --- lib/std/http/Client.zig | 102 ++++++++++++++++++++++++++++----------- lib/std/std.zig | 11 +++-- test/standalone/http.zig | 4 ++ 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 55ae62a183..4107cfdcc8 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -12,8 +12,7 @@ const assert = std.debug.assert; const Client = @This(); const proto = @import("protocol.zig"); -pub const default_connection_pool_size = 32; -pub const connection_pool_size = std.options.http_connection_pool_size; +pub const disable_tls = std.options.http_disable_tls; allocator: Allocator, ca_bundle: std.crypto.Certificate.Bundle = .{}, @@ -50,7 +49,7 @@ pub const ConnectionPool = struct { /// Open connections that are not currently in use. free: Queue = .{}, free_len: usize = 0, - free_size: usize = connection_pool_size, + free_size: usize = 32, /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. @@ -127,23 +126,43 @@ pub const ConnectionPool = struct { pool.used.append(node); } - pub fn deinit(pool: *ConnectionPool, client: *Client) void { + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. 
+ pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.close(client.allocator); + node.data.close(allocator); } next = pool.used.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.close(client.allocator); + node.data.close(allocator); } pool.* = undefined; @@ -159,7 +178,7 @@ pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. 
- tls_client: *std.crypto.tls.Client, + tls_client: if (!disable_tls) *std.crypto.tls.Client else void, protocol: Protocol, host: []u8, @@ -174,11 +193,8 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readv(buffers), - .tls => conn.tls_client.readv(conn.stream, buffers), - } catch |err| { + pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + return conn.tls_client.readv(conn.stream, buffers) catch |err| { // TODO: https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; @@ -191,6 +207,20 @@ pub const Connection = struct { }; } + pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; @@ -257,11 +287,21 @@ pub const Connection = struct { return Reader{ .context = conn }; } + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - .tls => conn.tls_client.writeAll(conn.stream, 
buffer), - } catch |err| switch (err) { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -303,6 +343,8 @@ pub const Connection = struct { pub fn close(conn: *Connection, allocator: Allocator) void { if (conn.protocol == .tls) { + if (disable_tls) unreachable; + // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; allocator.destroy(conn.tls_client); @@ -932,7 +974,7 @@ pub const ProxyInformation = struct { /// Release all associated resources with the client. /// TODO: currently leaks all request allocated data pub fn deinit(client: *Client) void { - client.connection_pool.deinit(client); + client.connection_pool.deinit(client.allocator); if (client.http_proxy) |*proxy| { proxy.allocator.free(proxy.host); @@ -1046,6 +1088,9 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec })) |node| return node; + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; @@ -1073,17 +1118,16 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer client.allocator.free(conn.data.host); - switch (protocol) { - .plain => {}, - .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.data.tls_client); + if (protocol == .tls) { + if (disable_tls) unreachable; - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers 
contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; - }, + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(conn.data.tls_client); + + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + conn.data.tls_client.allow_truncation_attacks = true; } client.connection_pool.addUsed(conn); diff --git a/lib/std/std.zig b/lib/std/std.zig index 16222e52da..5829a241c4 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -283,10 +283,15 @@ pub const options = struct { else false; - pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size")) - options_override.http_connection_pool_size + /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to + /// disable TLS support. + /// + /// This will likely reduce the size of the binary, but it will also make it impossible to + /// make a HTTPS connection. 
+ pub const http_disable_tls = if (@hasDecl(options_override, "http_disable_tls")) + options_override.http_disable_tls else - http.Client.default_connection_pool_size; + false; pub const side_channels_mitigations: crypto.SideChannelsMitigations = if (@hasDecl(options_override, "side_channels_mitigations")) options_override.side_channels_mitigations diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 71f3481767..55a8456fde 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -7,6 +7,10 @@ const Client = http.Client; const mem = std.mem; const testing = std.testing; +pub const std_options = struct { + pub const http_disable_tls = true; +}; + const max_header_size = 8192; var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; From e11a8397602b23997e9ebfbd1b865dab4f6a18d1 Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 5 Oct 2023 12:29:49 -0500 Subject: [PATCH 04/12] std.http: use loadDefaultProxies in compiler --- lib/std/std.zig | 2 +- src/main.zig | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/std/std.zig b/lib/std/std.zig index 5829a241c4..7342cadbef 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -285,7 +285,7 @@ pub const options = struct { /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to /// disable TLS support. - /// + /// /// This will likely reduce the size of the binary, but it will also make it impossible to /// make a HTTPS connection. 
pub const http_disable_tls = if (@hasDecl(options_override, "http_disable_tls")) diff --git a/src/main.zig b/src/main.zig index 07e7d088ea..8bf2cc1dd0 100644 --- a/src/main.zig +++ b/src/main.zig @@ -5128,6 +5128,8 @@ pub fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !voi var http_client: std.http.Client = .{ .allocator = gpa }; defer http_client.deinit(); + try http_client.loadDefaultProxies(); + var progress: std.Progress = .{ .dont_print_on_dumb = true }; const root_prog_node = progress.start("Fetch Packages", 0); defer root_prog_node.end(); @@ -7039,6 +7041,8 @@ fn cmdFetch( var http_client: std.http.Client = .{ .allocator = gpa }; defer http_client.deinit(); + try http_client.loadDefaultProxies(); + var progress: std.Progress = .{ .dont_print_on_dumb = true }; const root_prog_node = progress.start("Fetch", 0); defer root_prog_node.end(); From 16f89eab45c4356f8f1d824c342e76aee8ddff90 Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 5 Oct 2023 13:59:16 -0500 Subject: [PATCH 05/12] std.http.Client: make transfer-encoding priority over content-length as per spec --- lib/std/http/Client.zig | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 4107cfdcc8..fe085cfeca 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -420,13 +420,7 @@ pub const Response = struct { if (trailing) continue; - if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { // Transfer-Encoding: second, first // Transfer-Encoding: deflate, chunked var iter = 
mem.splitBackwardsScalar(u8, header_value, ','); @@ -458,6 +452,12 @@ pub const Response = struct { } if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { if (res.transfer_compression != null) return error.HttpHeadersInvalid; @@ -658,17 +658,17 @@ pub const Request = struct { .none => {}, } } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - req.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { + if (has_transfer_encoding) { const transfer_encoding = req.headers.getFirstValue("transfer-encoding").?; if (std.mem.eql(u8, transfer_encoding, "chunked")) { req.transfer_encoding = .chunked; } else { return error.UnsupportedTransferEncoding; } + } else if (has_content_length) { + const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; + + req.transfer_encoding = .{ .content_length = content_length }; } else { req.transfer_encoding = .none; } From c523b5421be86cdd0591a5672eeaea30fb142fe4 Mon Sep 17 00:00:00 2001 From: Nameless Date: Fri, 6 Oct 2023 21:38:05 -0500 Subject: [PATCH 06/12] std.http: make encoding fields non-null, store as enum variant --- lib/std/http.zig | 3 +++ lib/std/http/Client.zig | 49 +++++++++++++++++++-------------------- lib/std/http/Server.zig | 51 +++++++++++++++++++---------------------- 3 files changed, 50 insertions(+), 53 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index 
532e5a6de8..3f7af6b6e3 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -289,14 +289,17 @@ pub const Status = enum(u10) { pub const TransferEncoding = enum { chunked, + none, // compression is intentionally omitted here, as std.http.Client stores it as content-encoding }; pub const ContentEncoding = enum { identity, compress, + @"x-compress", deflate, gzip, + @"x-gzip", zstd, }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index fe085cfeca..558c1e18dd 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -425,27 +425,23 @@ pub const Response = struct { // Transfer-Encoding: deflate, chunked var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |first| { - const trimmed = mem.trim(u8, first, " "); + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (res.transfer_encoding != null) return error.HttpHeadersInvalid; - res.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (res.transfer_compression != null) return error.HttpHeadersInvalid; - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); } - if (iter.next()) |second| { - if (res.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - const trimmed = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (res.transfer_compression != 
.identity) return error.HttpHeadersInvalid; // double compression is not supported + res.transfer_compression = transfer; } else { return error.HttpTransferEncodingUnsupported; } @@ -459,7 +455,7 @@ pub const Response = struct { res.content_length = content_length; } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != null) return error.HttpHeadersInvalid; + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; const trimmed = mem.trim(u8, header_value, " "); @@ -494,8 +490,8 @@ pub const Response = struct { reason: []const u8, content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, + transfer_encoding: http.TransferEncoding = .none, + transfer_compression: http.ContentEncoding = .identity, headers: http.Headers, parser: proto.HeadersParser, @@ -771,8 +767,9 @@ pub const Request = struct { req.connection.?.closing = true; } - if (req.response.transfer_encoding) |te| { - switch (te) { + if (req.response.transfer_encoding != .none) { + switch (req.response.transfer_encoding) { + .none => unreachable, .chunked => { req.response.parser.next_chunk_length = 0; req.response.parser.state = .chunk_head_size; @@ -840,19 +837,19 @@ pub const Request = struct { } else { req.response.skip = false; if (!req.response.parser.done) { - if (req.response.transfer_compression) |tc| switch (tc) { + switch (req.response.transfer_compression) { .identity => req.response.compression = .none, - .compress => return error.CompressionNotSupported, + .compress, .@"x-compress" => return error.CompressionNotSupported, .deflate => req.response.compression = .{ .deflate = std.compress.zlib.decompressStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed, }, - .gzip => req.response.compression = .{ + .gzip, .@"x-gzip" => req.response.compression = .{ .gzip = 
std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => req.response.compression = .{ .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), }, - }; + } } break; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index a0c4774e06..da53b2b05d 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -228,27 +228,23 @@ pub const Request = struct { // Transfer-Encoding: deflate, chunked var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |first| { - const trimmed = mem.trim(u8, first, " "); + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (req.transfer_encoding != null) return error.HttpHeadersInvalid; - req.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (req.transfer_compression != null) return error.HttpHeadersInvalid; - req.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (req.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + req.transfer_encoding = transfer; + + next = iter.next(); } - if (iter.next()) |second| { - if (req.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - const trimmed = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - req.transfer_compression = ce; + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + req.transfer_compression = 
transfer; } else { return error.HttpTransferEncodingUnsupported; } @@ -256,7 +252,7 @@ pub const Request = struct { if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (req.transfer_compression != null) return error.HttpHeadersInvalid; + if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; const trimmed = mem.trim(u8, header_value, " "); @@ -278,8 +274,8 @@ pub const Request = struct { version: http.Version, content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, + transfer_encoding: http.TransferEncoding = .none, + transfer_compression: http.ContentEncoding = .identity, headers: http.Headers, parser: proto.HeadersParser, @@ -511,8 +507,9 @@ pub const Response = struct { res.request.headers = .{ .allocator = res.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); - if (res.request.transfer_encoding) |te| { - switch (te) { + if (res.request.transfer_encoding != .none) { + switch (res.request.transfer_encoding) { + .none => unreachable, .chunked => { res.request.parser.next_chunk_length = 0; res.request.parser.state = .chunk_head_size; @@ -527,19 +524,19 @@ pub const Response = struct { } if (!res.request.parser.done) { - if (res.request.transfer_compression) |tc| switch (tc) { + switch (res.request.transfer_compression) { .identity => res.request.compression = .none, - .compress => return error.CompressionNotSupported, + .compress, .@"x-compress" => return error.CompressionNotSupported, .deflate => res.request.compression = .{ .deflate = std.compress.zlib.decompressStream(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, - .gzip => res.request.compression = .{ + .gzip, .@"x-gzip" => res.request.compression = .{ .gzip = std.compress.gzip.decompress(res.allocator, res.transferReader()) catch return 
error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), }, - }; + } } } @@ -754,7 +751,7 @@ test "HTTP server handles a chunked transfer coding request" { defer _ = res.reset(); try res.wait(); - try expect(res.request.transfer_encoding.? == .chunked); + try expect(res.request.transfer_encoding == .chunked); const server_body: []const u8 = "message from server!\n"; res.transfer_encoding = .{ .content_length = server_body.len }; From d4cf8ea0b7621f7757203ecdbaa760e99cbc455c Mon Sep 17 00:00:00 2001 From: Nameless Date: Sat, 7 Oct 2023 19:58:15 -0500 Subject: [PATCH 07/12] std.http.Client: improve documentation --- lib/std/http/Client.zig | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 558c1e18dd..b3c0f3e97c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -489,13 +489,21 @@ pub const Response = struct { status: http.Status, reason: []const u8, + /// If present, the number of bytes in the response body. content_length: ?u64 = null, + + /// If present, the transfer encoding of the response body, otherwise none. transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). transfer_compression: http.ContentEncoding = .identity, + /// The headers received from the server. headers: http.Headers, parser: proto.HeadersParser, compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the response body will be discarded. skip: bool = false, }; @@ -511,6 +519,8 @@ pub const Request = struct { method: http.Method, version: http.Version = .@"HTTP/1.1", headers: http.Headers, + + /// The transfer encoding of the request body. 
transfer_encoding: RequestTransfer = .none, redirects_left: u32, @@ -595,7 +605,7 @@ pub const Request = struct { raw_uri: bool = false, }; - /// Send the request to the server. + /// Send the HTTP request to the server. pub fn start(req: *Request, options: StartOptions) StartError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; @@ -730,6 +740,8 @@ pub const Request = struct { /// /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow /// redirects. If a request payload is present, then this function will error with error.RedirectRequiresResend. + /// + /// Must be called after `start` and, if any data was written to the request body, then also after `finish`. pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers @@ -865,7 +877,7 @@ pub const Request = struct { return .{ .context = req }; } - /// Reads data from the response body. Must be called after `do`. + /// Reads data from the response body. Must be called after `wait`. pub fn read(req: *Request, buffer: []u8) ReadError!usize { const out_index = switch (req.response.compression) { .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, @@ -896,7 +908,7 @@ pub const Request = struct { return out_index; } - /// Reads data from the response body. Must be called after `do`. + /// Reads data from the response body. Must be called after `wait`. pub fn readAll(req: *Request, buffer: []u8) !usize { var index: usize = 0; while (index < buffer.len) { @@ -915,7 +927,8 @@ pub const Request = struct { return .{ .context = req }; } - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `start` and before `finish`. 
pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { @@ -936,6 +949,8 @@ pub const Request = struct { } } + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { @@ -946,6 +961,7 @@ pub const Request = struct { pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `start`. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), @@ -1134,6 +1150,8 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; +/// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. +/// This function is threadsafe. pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { if (!net.has_unix_sockets) return error.Unsupported; @@ -1166,6 +1184,8 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return &conn.data; } +/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP CONNECT. This will reuse a connection if one is already open. +/// This function is threadsafe. pub fn connectTunnel( client: *Client, proxy: *ProxyInformation, @@ -1245,6 +1265,11 @@ pub fn connectTunnel( const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; +/// Connect to `host:port` using the specified protocol. 
This will reuse a connection if one is already open. +/// +/// If a proxy is configured for the client, then the proxy will be used to connect to the host. +/// +/// This function is threadsafe. pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { // pointer required so that `supports_connect` can be updated if a CONNECT fails const potential_proxy: ?*ProxyInformation = switch (protocol) { @@ -1318,7 +1343,7 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ .{ "wss", .tls }, }); -/// Form and send a http request to a server. +/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// /// `uri` must remain alive during the entire request. /// `headers` is cloned and may be freed after this function returns. @@ -1420,6 +1445,9 @@ pub const FetchResult = struct { } }; +/// Perform a one-shot HTTP request with the provided options. +/// +/// This function is threadsafe. pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !FetchResult { const has_transfer_encoding = options.headers.contains("transfer-encoding"); const has_content_length = options.headers.contains("content-length"); From 544ed34d99f59a1b487341eaaa9610be44629924 Mon Sep 17 00:00:00 2001 From: Nameless Date: Sat, 7 Oct 2023 20:05:04 -0500 Subject: [PATCH 08/12] std.http.Server: improve documentation, do -> start Response.do was renamed to Response.start to mimic the naming scheme in http.Client --- lib/std/http/Client.zig | 2 +- lib/std/http/Server.zig | 22 ++++++++++++++++++---- test/standalone/http.zig | 24 ++++++++++++------------ 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index b3c0f3e97c..9eb363d752 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -605,7 +605,7 @@ pub const Request = struct { raw_uri: bool = false, }; - /// Send the HTTP request to the server. 
+ /// Send the HTTP request headers to the server. pub fn start(req: *Request, options: StartOptions) StartError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index da53b2b05d..6620433bd4 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -14,7 +14,7 @@ allocator: Allocator, socket: net.StreamServer, -/// An interface to either a plain or TLS connection. +/// An interface to a plain connection. pub const Connection = struct { pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; pub const Protocol = enum { plain }; @@ -273,8 +273,13 @@ pub const Request = struct { target: []const u8, version: http.Version, + /// The length of the request body, if known. content_length: ?u64 = null, + + /// The transfer encoding of the request body, or .none if not present. transfer_encoding: http.TransferEncoding = .none, + + /// The compression of the request body, or .identity (no compression) if not present. transfer_compression: http.ContentEncoding = .identity, headers: http.Headers, @@ -311,6 +316,7 @@ pub const Response = struct { finished, }; + /// Free all resources associated with this response. pub fn deinit(res: *Response) void { res.connection.close(); @@ -386,10 +392,10 @@ pub const Response = struct { } } - pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + pub const StartError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; - /// Send the response headers. - pub fn do(res: *Response) DoError!void { + /// Send the HTTP response headers to the client. 
+ pub fn start(res: *Response) StartError!void { switch (res.state) { .waited => res.state = .responded, .first, .start, .responded, .finished => unreachable, @@ -548,6 +554,7 @@ pub const Response = struct { return .{ .context = res }; } + /// Reads data from the response body. Must be called after `wait`. pub fn read(res: *Response, buffer: []u8) ReadError!usize { switch (res.state) { .waited, .responded, .finished => {}, @@ -583,6 +590,7 @@ pub const Response = struct { return out_index; } + /// Reads data from the response body. Must be called after `wait`. pub fn readAll(res: *Response, buffer: []u8) !usize { var index: usize = 0; while (index < buffer.len) { @@ -602,6 +610,7 @@ pub const Response = struct { } /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn write(res: *Response, bytes: []const u8) WriteError!usize { switch (res.state) { .responded => {}, @@ -627,6 +636,8 @@ pub const Response = struct { } } + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { @@ -637,6 +648,7 @@ pub const Response = struct { pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `start`. pub fn finish(res: *Response) FinishError!void { switch (res.state) { .responded => res.state = .finished, @@ -651,6 +663,7 @@ pub const Response = struct { } }; +/// Create a new HTTP server. 
pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { return .{ .allocator = allocator, @@ -658,6 +671,7 @@ pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { }; } +/// Free all resources associated with this server. pub fn deinit(server: *Server) void { server.socket.deinit(); } diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 55a8456fde..a242bb5778 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -29,11 +29,11 @@ fn handleRequest(res: *Server.Response) !void { if (res.request.headers.contains("expect")) { if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) { res.status = .@"continue"; - try res.do(); + try res.start(); res.status = .ok; } else { res.status = .expectation_failed; - try res.do(); + try res.start(); return; } } @@ -54,7 +54,7 @@ fn handleRequest(res: *Server.Response) !void { try res.headers.append("content-type", "text/plain"); - try res.do(); + try res.start(); if (res.request.method != .HEAD) { try res.writeAll("Hello, "); try res.writeAll("World!\n"); @@ -65,7 +65,7 @@ fn handleRequest(res: *Server.Response) !void { } else if (mem.startsWith(u8, res.request.target, "/large")) { res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 }; - try res.do(); + try res.start(); var i: u32 = 0; while (i < 5) : (i += 1) { @@ -92,14 +92,14 @@ fn handleRequest(res: *Server.Response) !void { try testing.expectEqualStrings("14", res.request.headers.getFirstValue("content-length").?); } - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("World!\n"); try res.finish(); } else if (mem.eql(u8, res.request.target, "/trailer")) { res.transfer_encoding = .chunked; - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("World!\n"); // try res.finish(); @@ -110,7 +110,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", 
"../../get"); - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -120,7 +120,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", "/redirect/1"); - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -133,7 +133,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", location); - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -143,7 +143,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", "/redirect/3"); - try res.do(); + try res.start(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -154,11 +154,11 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", location); - try res.do(); + try res.start(); try res.finish(); } else { res.status = .not_found; - try res.do(); + try res.start(); } } From 363d0ee5e13f4ac3a93d246121edfd00ef9fd97b Mon Sep 17 00:00:00 2001 From: Nameless Date: Sat, 7 Oct 2023 20:27:39 -0500 Subject: [PATCH 09/12] std.http: rename start->send and request->open to be more inline with operation --- lib/std/http/Client.zig | 32 +++++++------- lib/std/http/Server.zig | 6 +-- src/Package/Fetch.zig | 4 +- src/Package/Fetch/git.zig | 12 +++--- test/standalone/http.zig | 90 +++++++++++++++++++-------------------- 5 files changed, 72 insertions(+), 72 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 9eb363d752..e0a8328bab 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -598,15 +598,15 @@ pub const Request = struct { }; } - pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + pub const 
SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - pub const StartOptions = struct { - /// Specifies that the uri should be used as is + pub const SendOptions = struct { + /// Specifies that the uri should be used as is. You guarantee that the uri is already escaped. raw_uri: bool = false, }; /// Send the HTTP request headers to the server. - pub fn start(req: *Request, options: StartOptions) StartError!void { + pub fn send(req: *Request, options: SendOptions) SendError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; const w = req.connection.?.writer(); @@ -733,14 +733,14 @@ pub const Request = struct { return index; } - pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; + pub const WaitError = RequestError || SendError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; /// Waits for a response from the server and parses any headers that are sent. /// This function will block until the final response is received. /// /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow /// redirects. If a request payload is present, then this function will error with error.RedirectRequiresResend. - /// + /// /// Must be called after `start` and, if any data was written to the request body, then also after `finish`. 
pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects @@ -845,7 +845,7 @@ pub const Request = struct { try req.redirect(resolved_url); - try req.start(.{}); + try req.send(.{}); } else { req.response.skip = false; if (!req.response.parser.done) { @@ -1223,7 +1223,7 @@ pub fn connectTunnel( // we can use a small buffer here because a CONNECT response should be very small var buffer: [8096]u8 = undefined; - var req = client.request(.CONNECT, uri, proxy.headers, .{ + var req = client.open(.CONNECT, uri, proxy.headers, .{ .handle_redirects = false, .connection = conn, .header_strategy = .{ .static = buffer[0..] }, @@ -1233,7 +1233,7 @@ pub fn connectTunnel( }; defer req.deinit(); - req.start(.{ .raw_uri = true }) catch |err| break :tunnel err; + req.send(.{ .raw_uri = true }) catch |err| break :tunnel err; req.wait() catch |err| break :tunnel err; if (req.response.status.class() == .server_error) { @@ -1266,9 +1266,9 @@ const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, Conn pub const ConnectError = ConnectErrorPartial || RequestError; /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. -/// +/// /// If a proxy is configured for the client, then the proxy will be used to connect to the host. -/// +/// /// This function is threadsafe. 
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { // pointer required so that `supports_connect` can be updated if a CONNECT fails @@ -1304,7 +1304,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio return client.connectTcp(host, port, protocol); } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{ +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || std.fmt.ParseIntError || Connection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, @@ -1350,7 +1350,7 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ /// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. -pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request { +pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request { const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = uri.port orelse switch (protocol) { @@ -1446,7 +1446,7 @@ pub const FetchResult = struct { }; /// Perform a one-shot HTTP request with the provided options. -/// +/// /// This function is threadsafe. 
pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !FetchResult { const has_transfer_encoding = options.headers.contains("transfer-encoding"); @@ -1459,7 +1459,7 @@ pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !Fetc .uri => |u| u, }; - var req = try request(client, options.method, uri, options.headers, .{ + var req = try open(client, options.method, uri, options.headers, .{ .header_strategy = options.header_strategy, .handle_redirects = options.payload == .none, }); @@ -1476,7 +1476,7 @@ pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !Fetc .none => {}, } - try req.start(.{ .raw_uri = options.raw_uri }); + try req.send(.{ .raw_uri = options.raw_uri }); switch (options.payload) { .string => |str| try req.writeAll(str), diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 6620433bd4..ff568f21fd 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -392,10 +392,10 @@ pub const Response = struct { } } - pub const StartError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + pub const SendError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; /// Send the HTTP response headers to the client. 
- pub fn start(res: *Response) StartError!void { + pub fn send(res: *Response) SendError!void { switch (res.state) { .waited => res.state = .responded, .first, .start, .responded, .finished => unreachable, @@ -771,7 +771,7 @@ test "HTTP server handles a chunked transfer coding request" { res.transfer_encoding = .{ .content_length = server_body.len }; try res.headers.append("content-type", "text/plain"); try res.headers.append("connection", "close"); - try res.do(); + try res.send(); var buf: [128]u8 = undefined; const n = try res.readAll(&buf); diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 50468f4c2c..c10008aa16 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -826,7 +826,7 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { var h = std.http.Headers{ .allocator = gpa }; defer h.deinit(); - var req = http_client.request(.GET, uri, h, .{}) catch |err| { + var req = http_client.open(.GET, uri, h, .{}) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to connect to server: {s}", .{@errorName(err)}, @@ -834,7 +834,7 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { }; errdefer req.deinit(); // releases more than memory - req.start(.{}) catch |err| { + req.send(.{}) catch |err| { return f.fail(f.location_tok, try eb.printString( "HTTP request failed: {s}", .{@errorName(err)}, diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index af4317702d..9a2e682300 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -518,11 +518,11 @@ pub const Session = struct { defer headers.deinit(); try headers.append("Git-Protocol", "version=2"); - var request = try session.transport.request(.GET, info_refs_uri, headers, .{ + var request = try session.transport.open(.GET, info_refs_uri, headers, .{ .max_redirects = 3, }); errdefer request.deinit(); - try request.start(.{}); + try request.send(.{}); try request.finish(); try request.wait(); @@ -641,12 +641,12 @@ pub const 
Session = struct { } try Packet.write(.flush, body_writer); - var request = try session.transport.request(.POST, upload_pack_uri, headers, .{ + var request = try session.transport.open(.POST, upload_pack_uri, headers, .{ .handle_redirects = false, }); errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; - try request.start(.{}); + try request.send(.{}); try request.writeAll(body.items); try request.finish(); @@ -740,12 +740,12 @@ pub const Session = struct { try Packet.write(.{ .data = "done\n" }, body_writer); try Packet.write(.flush, body_writer); - var request = try session.transport.request(.POST, upload_pack_uri, headers, .{ + var request = try session.transport.open(.POST, upload_pack_uri, headers, .{ .handle_redirects = false, }); errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; - try request.start(.{}); + try request.send(.{}); try request.writeAll(body.items); try request.finish(); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index a242bb5778..b79aabd0fb 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -29,11 +29,11 @@ fn handleRequest(res: *Server.Response) !void { if (res.request.headers.contains("expect")) { if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) { res.status = .@"continue"; - try res.start(); + try res.send(); res.status = .ok; } else { res.status = .expectation_failed; - try res.start(); + try res.send(); return; } } @@ -54,7 +54,7 @@ fn handleRequest(res: *Server.Response) !void { try res.headers.append("content-type", "text/plain"); - try res.start(); + try res.send(); if (res.request.method != .HEAD) { try res.writeAll("Hello, "); try res.writeAll("World!\n"); @@ -65,7 +65,7 @@ fn handleRequest(res: *Server.Response) !void { } else if (mem.startsWith(u8, res.request.target, "/large")) { res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 }; - try res.start(); + try 
res.send(); var i: u32 = 0; while (i < 5) : (i += 1) { @@ -92,14 +92,14 @@ fn handleRequest(res: *Server.Response) !void { try testing.expectEqualStrings("14", res.request.headers.getFirstValue("content-length").?); } - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("World!\n"); try res.finish(); } else if (mem.eql(u8, res.request.target, "/trailer")) { res.transfer_encoding = .chunked; - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("World!\n"); // try res.finish(); @@ -110,7 +110,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", "../../get"); - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -120,7 +120,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", "/redirect/1"); - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -133,7 +133,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", location); - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -143,7 +143,7 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", "/redirect/3"); - try res.start(); + try res.send(); try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); @@ -154,11 +154,11 @@ fn handleRequest(res: *Server.Response) !void { res.status = .found; try res.headers.append("location", location); - try res.start(); + try res.send(); try res.finish(); } else { res.status = .not_found; - try res.start(); + try res.send(); } } @@ -244,10 +244,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, 
.{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -269,10 +269,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192 * 1024); @@ -293,10 +293,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.HEAD, uri, h, .{}); + var req = try client.open(.HEAD, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -319,10 +319,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -344,10 +344,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.HEAD, uri, h, .{}); + var req = try client.open(.HEAD, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -370,10 +370,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -397,12 +397,12 @@ pub fn main() !void 
{ const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.POST, uri, h, .{}); + var req = try client.open(.POST, uri, h, .{}); defer req.deinit(); req.transfer_encoding = .{ .content_length = 14 }; - try req.start(.{}); + try req.send(.{}); try req.writeAll("Hello, "); try req.writeAll("World!\n"); try req.finish(); @@ -429,10 +429,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -456,12 +456,12 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.POST, uri, h, .{}); + var req = try client.open(.POST, uri, h, .{}); defer req.deinit(); req.transfer_encoding = .chunked; - try req.start(.{}); + try req.send(.{}); try req.writeAll("Hello, "); try req.writeAll("World!\n"); try req.finish(); @@ -486,10 +486,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -510,10 +510,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -534,10 +534,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try 
client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); try req.wait(); const body = try req.reader().readAllAlloc(calloc, 8192); @@ -558,10 +558,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); req.wait() catch |err| switch (err) { error.TooManyHttpRedirects => {}, else => return err, @@ -580,10 +580,10 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.GET, uri, h, .{}); + var req = try client.open(.GET, uri, h, .{}); defer req.deinit(); - try req.start(.{}); + try req.send(.{}); const result = req.wait(); // a proxy without an upstream is likely to return a 5xx status. @@ -628,12 +628,12 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.POST, uri, h, .{}); + var req = try client.open(.POST, uri, h, .{}); defer req.deinit(); req.transfer_encoding = .chunked; - try req.start(.{}); + try req.send(.{}); try req.wait(); try testing.expectEqual(http.Status.@"continue", req.response.status); @@ -662,12 +662,12 @@ pub fn main() !void { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var req = try client.request(.POST, uri, h, .{}); + var req = try client.open(.POST, uri, h, .{}); defer req.deinit(); req.transfer_encoding = .chunked; - try req.start(.{}); + try req.send(.{}); try req.wait(); try testing.expectEqual(http.Status.expectation_failed, req.response.status); } @@ -682,7 +682,7 @@ pub fn main() !void { defer calloc.free(requests); for (0..total_connections) |i| { - var req = try client.request(.GET, uri, .{ .allocator = calloc }, .{}); + var req = try client.open(.GET, uri, .{ .allocator = calloc }, .{}); req.response.parser.done = true; 
req.connection.?.closing = false; requests[i] = req; From 7dd3099519fd0f64fcdf11791fc9ba95a68e0637 Mon Sep 17 00:00:00 2001 From: Nameless Date: Tue, 17 Oct 2023 19:08:22 -0500 Subject: [PATCH 10/12] std.http: fix crashes found via fuzzing --- lib/std/http.zig | 3 ++- lib/std/http/Client.zig | 28 +++++++++++++++++++++++----- lib/std/http/Headers.zig | 11 +++++++---- lib/std/http/Server.zig | 2 +- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index 3f7af6b6e3..9487e82106 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -35,7 +35,8 @@ pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is s /// Asserts that `s` is 24 or fewer bytes. pub fn parse(s: []const u8) u64 { var x: u64 = 0; - @memcpy(std.mem.asBytes(&x)[0..s.len], s); + const len = @min(s.len, @sizeOf(@TypeOf(x))); + @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); return x; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index e0a8328bab..6de2986166 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -15,7 +15,7 @@ const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; allocator: Allocator, -ca_bundle: std.crypto.Certificate.Bundle = .{}, +ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, /// When this is `true`, the next time this client performs an HTTPS request, @@ -386,7 +386,7 @@ pub const Response = struct { }; pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void { - var it = mem.tokenizeAny(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); + var it = mem.tokenizeAny(u8, bytes, "\r\n"); const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 12) @@ -405,6 +405,8 @@ pub const Response = struct { res.status = status; res.reason = reason; + res.headers.clearRetainingCapacity(); + while (it.next()) |line| { if (line.len == 0) return error.HttpHeadersInvalid; switch (line[0]) { @@ -525,6 +527,7 @@ pub const Request = struct { redirects_left: u32, handle_redirects: bool, + handle_continue: bool, response: Response, @@ -758,6 +761,10 @@ pub const Request = struct { if (req.response.status == .@"continue") { req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response req.response.parser.reset(); + + if (req.handle_continue) + continue; + break; } @@ -897,8 +904,6 @@ pub const Request = struct { } if (has_trail) { - req.response.headers.clearRetainingCapacity(); - // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. // This will *only* fail for a malformed trailer. req.response.parse(req.response.parser.header_bytes.items, true) catch return error.InvalidTrailers; @@ -999,7 +1004,9 @@ pub fn deinit(client: *Client) void { proxy.headers.deinit(); } - client.ca_bundle.deinit(client.allocator); + if (!disable_tls) + client.ca_bundle.deinit(client.allocator); + client.* = undefined; } @@ -1315,6 +1322,14 @@ pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendE pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", + /// Automatically ignore 100 Continue responses. This assumes you don't care, and will have sent the body before you + /// wait for the response. + /// + /// If this is not the case AND you know the server will send a 100 Continue, set this to false and wait for a + /// response before sending the body. 
If you wait AND the server does not send a 100 Continue before you finish the + /// request, then the request *will* deadlock. + handle_continue: bool = true, + handle_redirects: bool = true, max_redirects: u32 = 3, header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 }, @@ -1361,6 +1376,8 @@ pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Header const host = uri.host orelse return error.UriMissingHost; if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) { + if (disable_tls) unreachable; + client.ca_bundle_mutex.lock(); defer client.ca_bundle_mutex.unlock(); @@ -1381,6 +1398,7 @@ pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Header .version = options.version, .redirects_left = options.max_redirects, .handle_redirects = options.handle_redirects, + .handle_continue = options.handle_continue, .response = .{ .status = undefined, .reason = undefined, diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig index f69f7fef5f..d2e578b5ee 100644 --- a/lib/std/http/Headers.zig +++ b/lib/std/http/Headers.zig @@ -14,15 +14,18 @@ pub const CaseInsensitiveStringContext = struct { pub fn hash(self: @This(), s: []const u8) u64 { _ = self; var buf: [64]u8 = undefined; - var i: u8 = 0; + var i: usize = 0; var h = std.hash.Wyhash.init(0); - while (i < s.len) : (i += 64) { - const left = @min(64, s.len - i); - const ret = ascii.lowerString(buf[0..], s[i..][0..left]); + while (i + 64 < s.len) : (i += 64) { + const ret = ascii.lowerString(buf[0..], s[i..][0..64]); h.update(ret); } + const left = @min(64, s.len - i); + const ret = ascii.lowerString(buf[0..], s[i..][0..left]); + h.update(ret); + return h.final(); } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index ff568f21fd..057d1ef5ca 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -178,7 +178,7 @@ pub const Request = struct { }; pub fn parse(req: *Request, bytes: []const u8) ParseError!void { - 
var it = mem.tokenizeAny(u8, bytes[0 .. bytes.len - 4], "\r\n"); + var it = mem.tokenizeAny(u8, bytes, "\r\n"); const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 10) From dd010e9e90c5ff6cbd5f390dbbb534ddf2fc87b6 Mon Sep 17 00:00:00 2001 From: Nameless Date: Wed, 18 Oct 2023 11:03:27 -0500 Subject: [PATCH 11/12] std.http.Client: ignore unknown proxies, fix basic proxy auth --- lib/std/http/Client.zig | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 6de2986166..832ce5a1dc 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1028,13 +1028,14 @@ pub fn loadDefaultProxies(client: *Client) !void { const uri = try Uri.parse(content); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; + const protocol = protocol_map.get(uri.scheme) orelse break :http; // Unknown scheme, ignore + const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :http; // Missing host, ignore client.http_proxy = .{ .allocator = client.allocator, .headers = .{ .allocator = client.allocator }, .protocol = protocol, - .host = if (uri.host) |host| try client.allocator.dupe(u8, host) else return error.UriMissingHost, + .host = host, .port = uri.port orelse switch (protocol) { .plain => 80, .tls => 443, @@ -1042,13 +1043,16 @@ pub fn loadDefaultProxies(client: *Client) !void { }; if (uri.user != null and uri.password != null) { + const prefix_len = "Basic ".len; + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? 
}); defer client.allocator.free(unencoded); - const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len)); + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len) + prefix_len); defer client.allocator.free(buffer); - const result = std.base64.standard.Encoder.encode(buffer, unencoded); + const result = std.base64.standard.Encoder.encode(buffer[prefix_len..], unencoded); + @memcpy(buffer[0..prefix_len], "Basic "); try client.http_proxy.?.headers.append("proxy-authorization", result); } @@ -1069,13 +1073,14 @@ pub fn loadDefaultProxies(client: *Client) !void { const uri = try Uri.parse(content); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; + const protocol = protocol_map.get(uri.scheme) orelse break :https; // Unknown scheme, ignore + const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :https; // Missing host, ignore client.http_proxy = .{ .allocator = client.allocator, .headers = .{ .allocator = client.allocator }, .protocol = protocol, - .host = if (uri.host) |host| try client.allocator.dupe(u8, host) else return error.UriMissingHost, + .host = host, .port = uri.port orelse switch (protocol) { .plain => 80, .tls => 443, @@ -1083,13 +1088,16 @@ pub fn loadDefaultProxies(client: *Client) !void { }; if (uri.user != null and uri.password != null) { + const prefix_len = "Basic ".len; + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? 
}); defer client.allocator.free(unencoded); - const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len)); + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len) + prefix_len); defer client.allocator.free(buffer); - const result = std.base64.standard.Encoder.encode(buffer, unencoded); + const result = std.base64.standard.Encoder.encode(buffer[prefix_len..], unencoded); + @memcpy(buffer[0..prefix_len], "Basic "); try client.https_proxy.?.headers.append("proxy-authorization", result); } From 93e1f8c8e583b3140bc1985e8b346fd7aca8cf6b Mon Sep 17 00:00:00 2001 From: Nameless Date: Fri, 20 Oct 2023 20:13:25 -0500 Subject: [PATCH 12/12] std.http.Client: documentation fixes --- lib/std/http/Client.zig | 43 +++++++++++++++++++++++++++++----------- test/standalone/http.zig | 3 --- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 832ce5a1dc..23bbd994a9 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -14,7 +14,11 @@ const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; +/// Allocator used for all allocations made by the client. +/// +/// This allocator must be thread-safe. allocator: Allocator, + ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, @@ -26,10 +30,10 @@ next_https_rescan_certs: bool = true, connection_pool: ConnectionPool = .{}, /// This is the proxy that will handle http:// connections. It *must not* be modified when the client has any active connections. -http_proxy: ?ProxyInformation = null, +http_proxy: ?Proxy = null, /// This is the proxy that will handle https:// connections. It *must not* be modified when the client has any active connections.
-https_proxy: ?ProxyInformation = null, +https_proxy: ?Proxy = null, /// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { @@ -61,6 +65,8 @@ pub const ConnectionPool = struct { while (next) |node| : (next = node.prev) { if (node.data.protocol != criteria.protocol) continue; if (node.data.port != criteria.port) continue; + + // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); @@ -88,6 +94,9 @@ pub const ConnectionPool = struct { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. + /// + /// The allocator must be the owner of all nodes in this pool. + /// The allocator must be the owner of all resources associated with the connection. pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -195,7 +204,7 @@ pub const Connection = struct { pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // TODO: https://github.com/ziglang/zig/issues/2473 + // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; switch (err) { @@ -978,7 +987,7 @@ pub const Request = struct { } }; -pub const ProxyInformation = struct { +pub const Proxy = struct { allocator: Allocator, headers: http.Headers, @@ -990,8 +999,12 @@ pub const ProxyInformation = struct { }; /// Release all associated resources with the client. -/// TODO: currently leaks all request allocated data +/// +/// All pending requests must be de-initialized and all active connections released +/// before calling this function. 
pub fn deinit(client: *Client) void { + assert(client.connection_pool.used.first == null); // There are still active requests. + client.connection_pool.deinit(client.allocator); if (client.http_proxy) |*proxy| { @@ -1013,6 +1026,12 @@ pub fn deinit(client: *Client) void { /// Uses the *_proxy environment variable to set any unset proxies for the client. /// This function *must not* be called when the client has any active connections. pub fn loadDefaultProxies(client: *Client) !void { + // Prevent any new connections from being created. + client.connection_pool.mutex.lock(); + defer client.connection_pool.mutex.unlock(); + + assert(client.connection_pool.used.first == null); // There are still active requests. + if (client.http_proxy == null) http: { const content: []const u8 = if (std.process.hasEnvVarConstant("http_proxy")) try std.process.getEnvVarOwned(client.allocator, "http_proxy") @@ -1203,7 +1222,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti /// This function is threadsafe. pub fn connectTunnel( client: *Client, - proxy: *ProxyInformation, + proxy: *Proxy, tunnel_host: []const u8, tunnel_port: u16, ) !*Connection { @@ -1217,7 +1236,7 @@ pub fn connectTunnel( return node; var maybe_valid = false; - _ = tunnel: { + (tunnel: { const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); errdefer { conn.closing = true; @@ -1241,7 +1260,7 @@ pub fn connectTunnel( var req = client.open(.CONNECT, uri, proxy.headers, .{ .handle_redirects = false, .connection = conn, - .header_strategy = .{ .static = buffer[0..] 
}, + .header_strategy = .{ .static = &buffer }, }) catch |err| { std.log.debug("err {}", .{err}); break :tunnel err; @@ -1269,7 +1288,7 @@ pub fn connectTunnel( conn.closing = false; return conn; - } catch { + }) catch { // something went wrong with the tunnel proxy.supports_connect = maybe_valid; return error.TunnelNotSupported; @@ -1287,7 +1306,7 @@ pub const ConnectError = ConnectErrorPartial || RequestError; /// This function is threadsafe. pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { // pointer required so that `supports_connect` can be updated if a CONNECT fails - const potential_proxy: ?*ProxyInformation = switch (protocol) { + const potential_proxy: ?*Proxy = switch (protocol) { .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, .tls => if (client.https_proxy) |*proxy_info| proxy_info else null, }; @@ -1298,12 +1317,12 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio return client.connectTcp(host, port, protocol); } - _ = if (proxy.supports_connect) tunnel: { + if (proxy.supports_connect) tunnel: { return connectTunnel(client, proxy, host, port) catch |err| switch (err) { error.TunnelNotSupported => break :tunnel, else => |e| return e, }; - }; + } // fall back to using the proxy as a normal http proxy const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index b79aabd0fb..c48a5ae1b9 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -634,9 +634,6 @@ pub fn main() !void { req.transfer_encoding = .chunked; try req.send(.{}); - try req.wait(); - try testing.expectEqual(http.Status.@"continue", req.response.status); - try req.writeAll("Hello, "); try req.writeAll("World!\n"); try req.finish();