std.http: very basic http client proxy

This commit is contained in:
Nameless 2023-04-14 12:38:13 -05:00
parent 2c492064fb
commit 96533b1289
No known key found for this signature in database
GPG Key ID: A477BC03CAFCCAF7
4 changed files with 155 additions and 50 deletions

View File

@ -27,6 +27,18 @@ pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfM
return escapeStringWithFn(allocator, input, isQueryChar);
}
/// Writes `input` to `writer`, percent-encoding every byte that
/// `isUnreserved` rejects. No allocation is performed by this function.
pub fn writeEscapedString(writer: anytype, input: []const u8) !void {
    try writeEscapedStringWithFn(writer, input, isUnreserved);
}
/// Writes `input` to `writer`, percent-encoding every byte that
/// `isPathChar` rejects. No allocation is performed by this function.
pub fn writeEscapedPath(writer: anytype, input: []const u8) !void {
    try writeEscapedStringWithFn(writer, input, isPathChar);
}
/// Writes `input` to `writer`, percent-encoding every byte that
/// `isQueryChar` rejects. No allocation is performed by this function.
pub fn writeEscapedQuery(writer: anytype, input: []const u8) !void {
    try writeEscapedStringWithFn(writer, input, isQueryChar);
}
pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
var outsize: usize = 0;
for (input) |c| {
@ -52,6 +64,16 @@ pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, compt
return output;
}
/// Streams `input` to `writer`, percent-encoding on the fly.
///
/// Bytes for which `keepUnescaped` returns true are written unchanged; every
/// other byte is written as `%` followed by its two-digit uppercase hex value.
/// Any error reported by `writer` is returned to the caller.
pub fn writeEscapedStringWithFn(writer: anytype, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) @TypeOf(writer).Error!void {
    var index: usize = 0;
    while (index < input.len) : (index += 1) {
        const byte = input[index];
        if (!keepUnescaped(byte)) {
            // Zero-padded to width 2 so values below 0x10 still yield two hex digits.
            try writer.print("%{X:0>2}", .{byte});
            continue;
        }
        try writer.writeByte(byte);
    }
}
/// Parses a URI string and unescapes all %XX where XX is a valid hex number. Otherwise, verbatim copies
/// them to the output.
pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
@ -184,6 +206,60 @@ pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
return uri;
}
/// Writes `uri` to `writer`; the parts emitted are selected by the format specifier:
/// - `+` anywhere in `fmt`: the scheme and `:` are written and, when a host is set,
///   `//`, the optional `user[:password]@`, the host, and the optional `:port`.
/// - `/` anywhere in `fmt`, or an empty `fmt`: the percent-escaped path (`/` when the
///   path is empty), followed by `?query` and `#fragment` when those are present.
///
/// `options` is ignored. Scheme, user, password and host are written as stored,
/// without escaping; query and fragment are both escaped with `writeEscapedQuery`.
pub fn format(
    uri: Uri,
    comptime fmt: []const u8,
    options: std.fmt.FormatOptions,
    writer: anytype,
) @TypeOf(writer).Error!void {
    _ = options;

    const with_authority = comptime (std.mem.indexOf(u8, fmt, "+") != null);
    const with_path = comptime (fmt.len == 0 or std.mem.indexOf(u8, fmt, "/") != null);

    if (with_authority) {
        try writer.writeAll(uri.scheme);
        try writer.writeAll(":");

        // The authority part only exists when a host is known.
        if (uri.host) |hostname| {
            try writer.writeAll("//");

            if (uri.user) |username| {
                try writer.writeAll(username);
                if (uri.password) |secret| {
                    try writer.writeAll(":");
                    try writer.writeAll(secret);
                }
                try writer.writeAll("@");
            }

            try writer.writeAll(hostname);

            if (uri.port) |port_number| {
                try writer.writeAll(":");
                try std.fmt.formatInt(port_number, 10, .lower, .{}, writer);
            }
        }
    }

    if (with_path) {
        // An empty path is emitted as the root path.
        if (uri.path.len != 0) {
            try Uri.writeEscapedPath(writer, uri.path);
        } else {
            try writer.writeAll("/");
        }

        if (uri.query) |query_text| {
            try writer.writeAll("?");
            try Uri.writeEscapedQuery(writer, query_text);
        }

        if (uri.fragment) |fragment_text| {
            try writer.writeAll("#");
            try Uri.writeEscapedQuery(writer, fragment_text);
        }
    }
}
/// Parses the URI or returns an error.
/// The return value will contain unescaped strings pointing into the
/// original `text`. Each component that is provided will be non-`null`.

View File

@ -265,7 +265,7 @@ pub const Connection = enum {
close,
};
pub const CustomHeader = struct {
pub const Header = struct {
name: []const u8,
value: []const u8,
};

View File

@ -25,27 +25,7 @@ next_https_rescan_certs: bool = true,
/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
/// Tagged union holding a detailed error value, with one field per failure category
/// (compression init, connect, CA bundle, TLS, write, read, decompress).
/// NOTE(review): the trailing comment on each field names an `error.*` value — presumably
/// the coarse error reported to callers while this payload keeps the underlying cause; confirm.
pub const ExtraError = union(enum) {
// Error sets used as the payload types of the fields below.
pub const TcpConnectError = std.net.TcpConnectToHostError;
pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
pub const WriteError = BufferedConnection.WriteError;
// Read errors from the buffered connection, plus an invalid HTTP chunk.
pub const ReadError = BufferedConnection.ReadError || error{HttpChunkInvalid};
pub const CaBundleError = std.crypto.Certificate.Bundle.RescanError;
pub const ZlibInitError = error{ BadHeader, InvalidCompression, InvalidWindowSize, Unsupported, EndOfStream, OutOfMemory } || Request.TransferReadError;
pub const GzipInitError = error{ BadHeader, InvalidCompression, OutOfMemory, WrongChecksum, EndOfStream, StreamTooLong } || Request.TransferReadError;
// pub const DecompressError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error;
pub const DecompressError = anyerror; // FIXME: the above line causes a false positive dependency loop
zlib_init: ZlibInitError, // error.CompressionInitializationFailed
gzip_init: GzipInitError, // error.CompressionInitializationFailed
connect: TcpConnectError, // error.ConnectionFailed
ca_bundle: CaBundleError, // error.CertificateAuthorityBundleFailed
tls: TlsError, // error.TlsInitializationFailed
write: WriteError, // error.WriteFailed
read: ReadError, // error.ReadFailed
decompress: DecompressError, // error.ReadFailed
};
proxy: ?HttpProxy = null,
/// A set of linked lists of connections that can be reused.
pub const ConnectionPool = struct {
@ -61,6 +41,7 @@ pub const ConnectionPool = struct {
host: []u8,
port: u16,
proxied: bool = false,
closing: bool = false,
pub fn deinit(self: *StoredConnection, client: *Client) void {
@ -137,7 +118,12 @@ pub const ConnectionPool = struct {
return client.allocator.destroy(popped);
}
pool.free.append(node);
if (node.data.proxied) {
pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first
} else {
pool.free.append(node);
}
pool.free_len += 1;
}
@ -546,9 +532,10 @@ pub const Request = struct {
if (!req.response.parser.done) {
// If the response wasn't fully read, then we need to close the connection.
req.connection.data.closing = true;
req.client.connection_pool.release(req.client, req.connection);
}
req.client.connection_pool.release(req.client, req.connection);
req.arena.deinit();
req.* = undefined;
}
@ -557,30 +544,20 @@ pub const Request = struct {
var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
const w = buffered.writer();
const escaped_path = try Uri.escapePath(req.client.allocator, uri.path);
defer req.client.allocator.free(escaped_path);
const escaped_query = if (uri.query) |q| try Uri.escapeQuery(req.client.allocator, q) else null;
defer if (escaped_query) |q| req.client.allocator.free(q);
const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(req.client.allocator, f) else null;
defer if (escaped_fragment) |f| req.client.allocator.free(f);
try w.writeAll(@tagName(headers.method));
try w.writeByte(' ');
if (escaped_path.len == 0) {
try w.writeByte('/');
if (req.headers.method == .CONNECT) {
try w.writeAll(uri.host.?);
try w.writeByte(':');
try w.print("{}", .{uri.port.?});
} else if (req.connection.data.proxied) {
// proxied connections require the full uri
try w.print("{+/}", .{uri});
} else {
try w.writeAll(escaped_path);
}
if (escaped_query) |q| {
try w.writeByte('?');
try w.writeAll(q);
}
if (escaped_fragment) |f| {
try w.writeByte('#');
try w.writeAll(f);
try w.print("{/}", .{uri});
}
try w.writeByte(' ');
try w.writeAll(@tagName(headers.version));
try w.writeAll("\r\nHost: ");
@ -659,6 +636,12 @@ pub const Request = struct {
req.response.parser.done = true;
}
if (req.headers.method == .CONNECT and req.response.headers.status == .ok) {
req.connection.data.closing = false;
req.connection.data.proxied = true;
req.response.parser.done = true;
}
if (req.headers.connection == .keep_alive and req.response.headers.connection == .keep_alive) {
req.connection.data.closing = false;
} else {
@ -802,7 +785,7 @@ pub const Request = struct {
}
}
pub const FinishError = WriteError || error{ MessageNotCompleted };
pub const FinishError = WriteError || error{MessageNotCompleted};
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) FinishError!void {
@ -817,6 +800,20 @@ pub const Request = struct {
}
};
/// Describes the HTTP proxy that `connect` routes new connections through when
/// the client's `proxy` field is set.
pub const HttpProxy = struct {
/// Credential variants for the proxy; both carry a byte string.
/// NOTE(review): how `basic` and `custom` differ when sent is not visible here —
/// presumably `basic` is a Basic credential and `custom` a verbatim header value; confirm.
pub const ProxyAuthentication = union(enum) {
basic: []const u8,
custom: []const u8,
};
/// Protocol used for the connection to the proxy itself.
protocol: Connection.Protocol,
/// Host of the proxy; passed to `connectUnproxied` by `connect`.
host: []const u8,
/// Port of the proxy. When null, `connect` uses 80 for `.plain` and 443 for `.tls`.
port: ?u16 = null,
/// The value for the Proxy-Authorization header.
auth: ?ProxyAuthentication = null,
};
/// Release all associated resources with the client.
/// TODO: currently leaks all request allocated data
pub fn deinit(client: *Client) void {
@ -826,11 +823,11 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
pub const ConnectError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
pub const ConnectUnproxiedError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
/// This function is threadsafe.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node {
if (client.connection_pool.findConnection(.{
.host = host,
.port = port,
@ -884,7 +881,34 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
return conn;
}
pub const RequestError = ConnectError || BufferedConnection.WriteError || error{
// Prevents a dependency loop in request()
const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
pub const ConnectError = ConnectErrorPartial || RequestError;
/// Acquires a connection for `host:port`, going through `client.proxy` when one is configured.
///
/// A pooled connection matching `host`, `port` and the TLS setting is returned first if one
/// exists. Otherwise, with a proxy configured, the connection is made to the proxy (its port
/// defaulting to 80 for `.plain` and 443 for `.tls`) and the node is marked `proxied`; with no
/// proxy, the connection is made directly to `host:port`.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
    const pooled = client.connection_pool.findConnection(.{
        .host = host,
        .port = port,
        .is_tls = protocol == .tls,
    });
    if (pooled) |existing| return existing;

    // No proxy configured: connect straight to the requested host.
    const proxy = client.proxy orelse return client.connectUnproxied(host, port, protocol);

    const proxy_port: u16 = if (proxy.port) |explicit| explicit else switch (proxy.protocol) {
        .plain => 80,
        .tls => 443,
    };

    const node = try client.connectUnproxied(proxy.host, proxy_port, proxy.protocol);
    node.data.proxied = true;
    return node;
}
pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || BufferedConnection.WriteError || error{
UnsupportedUrlScheme,
UriMissingHost,
@ -896,6 +920,9 @@ pub const Options = struct {
max_redirects: u32 = 3,
header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
/// Must be an already acquired connection.
connection: ?*ConnectionPool.Node = null,
pub const HeaderStrategy = union(enum) {
/// In this case, the client's Allocator will be used to store the
/// entire HTTP header. This value is the maximum total size of
@ -939,10 +966,12 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
}
}
const conn = options.connection orelse try client.connect(host, port, protocol);
var req: Request = .{
.uri = uri,
.client = client,
.connection = try client.connect(host, port, protocol),
.connection = conn,
.headers = headers,
.redirects_left = options.max_redirects,
.handle_redirects = options.handle_redirects,

View File

@ -1,4 +1,4 @@
const std = @import("std");
const std = @import("../std.zig");
const testing = std.testing;
const mem = std.mem;