std.http: handle relative redirects

This commit is contained in:
Nameless 2023-03-06 23:35:35 -06:00
parent fd2f906d1e
commit 0a4130f364
No known key found for this signature in database
GPG Key ID: A477BC03CAFCCAF7
4 changed files with 196 additions and 61 deletions

View File

@ -16,15 +16,27 @@ fragment: ?[]const u8,
/// Percent-encodes `input`, keeping only the bytes accepted by `isUnreserved`
/// as-is; every other byte is replaced with its %XX code.
/// The result is allocated with `allocator`; the caller owns the returned slice.
pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
    const escaped = try escapeStringWithFn(allocator, input, isUnreserved);
    return escaped;
}
/// Percent-encodes `input` for use as a URI path: bytes accepted by
/// `isPathChar` are copied through, every other byte is escaped.
/// The result is allocated with `allocator`; the caller owns the returned slice.
pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
    const escaped = try escapeStringWithFn(allocator, input, isPathChar);
    return escaped;
}
/// Percent-encodes `input` for use as a URI query: bytes accepted by
/// `isQueryChar` are copied through, every other byte is escaped.
/// The result is allocated with `allocator`; the caller owns the returned slice.
pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
    const escaped = try escapeStringWithFn(allocator, input, isQueryChar);
    return escaped;
}
pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
var outsize: usize = 0;
for (input) |c| {
outsize += if (isUnreserved(c)) @as(usize, 1) else 3;
outsize += if (keepUnescaped(c)) @as(usize, 1) else 3;
}
var output = try allocator.alloc(u8, outsize);
var outptr: usize = 0;
for (input) |c| {
if (isUnreserved(c)) {
if (keepUnescaped(c)) {
output[outptr] = c;
outptr += 1;
} else {
@ -94,13 +106,14 @@ pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{Out
pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort };
/// Parses the URI or returns an error.
/// Parses the URI or returns an error. This function is not compliant, but is required to parse
/// some forms of URIs found in the wild, such as HTTP Location headers.
/// The return value will contain unescaped strings pointing into the
/// original `text`. Each component that is provided, will be non-`null`.
pub fn parse(text: []const u8) ParseError!Uri {
pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
var reader = SliceReader{ .slice = text };
var uri = Uri{
.scheme = reader.readWhile(isSchemeChar),
.scheme = "",
.user = null,
.password = null,
.host = null,
@ -110,14 +123,6 @@ pub fn parse(text: []const u8) ParseError!Uri {
.fragment = null,
};
// after the scheme, a ':' must appear
if (reader.get()) |c| {
if (c != ':')
return error.UnexpectedCharacter;
} else {
return error.InvalidFormat;
}
if (reader.peekPrefix("//")) { // authority part
std.debug.assert(reader.get().? == '/');
std.debug.assert(reader.get().? == '/');
@ -179,6 +184,76 @@ pub fn parse(text: []const u8) ParseError!Uri {
return uri;
}
/// Parses a URI that starts with a scheme, or returns an error.
/// The scheme is read first and must be followed by ':'; everything after the
/// ':' is handed to `parseWithoutScheme`.
/// The return value will contain unescaped strings pointing into the
/// original `text`. Each component that is provided, will be non-`null`.
///
/// Errors:
/// - `error.InvalidFormat` when the input ends right after the scheme.
/// - `error.UnexpectedCharacter` when the scheme is not followed by ':'.
/// - anything `parseWithoutScheme` reports for the remainder.
pub fn parse(text: []const u8) ParseError!Uri {
    var reader = SliceReader{ .slice = text };
    const scheme = reader.readWhile(isSchemeChar);

    // after the scheme, a ':' must appear
    const delimiter = reader.get() orelse return error.InvalidFormat;
    if (delimiter != ':') return error.UnexpectedCharacter;

    var uri = try parseWithoutScheme(reader.readUntilEof());
    uri.scheme = scheme;
    return uri;
}
/// Resolves the reference `R` against the base URI `Base`, following the
/// algorithm of RFC 3986, Section 5.2.
///
/// - `Base`: the URI the reference is relative to.
/// - `R`: the (possibly relative) reference, e.g. from `parseWithoutScheme`.
/// - `strict`: when `false`, a reference whose scheme equals the base scheme
///   is treated as if it had no scheme at all (RFC 3986 backwards-compatibility
///   behavior).
/// - `arena`: owns any memory allocated by this function (the normalized path).
///
/// Returns the target URI. Components that are not re-allocated still point
/// into the memory backing `Base` or `R`. Fails only if allocating the
/// normalized path fails.
///
/// NOTE(review): dot-segment removal is delegated to `std.fs.path.resolvePosix`,
/// a file system path helper; confirm it treats trailing slashes the way the
/// RFC's `remove_dot_segments` does.
pub fn resolve(Base: Uri, R: Uri, strict: bool, arena: std.mem.Allocator) !Uri {
    // Start from the base so that every component (including `password`) is
    // initialized; the branches below overwrite what the reference provides.
    var target = Base;
    target.fragment = R.fragment;

    // In non-strict mode a scheme identical to the base scheme is ignored.
    const scheme_ignored = !strict and std.mem.eql(u8, R.scheme, Base.scheme);
    const has_own_scheme = R.scheme.len > 0 and !scheme_ignored;
    if (has_own_scheme) target.scheme = R.scheme;

    // The reference carries its own scheme or authority: take the authority,
    // path and query from the reference.
    if (has_own_scheme or R.host != null) {
        target.user = R.user;
        target.password = R.password;
        target.host = R.host;
        target.port = R.port;
        target.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
        target.query = R.query;
        return target;
    }

    // From here on the authority (user, password, host, port) is the base's.

    if (R.path.len == 0) {
        // Empty path: keep the base path; the query is only replaced when the
        // reference has one.
        if (R.query) |query| target.query = query;
        return target;
    }

    if (R.path[0] == '/') {
        // Absolute path reference.
        target.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
    } else {
        // Merge (RFC 3986, Section 5.2.3): the reference replaces everything
        // after the last '/' of the base path, so the base's final segment
        // must be dropped before joining.
        const base_dir = if (std.mem.lastIndexOfScalar(u8, Base.path, '/')) |last_slash|
            Base.path[0 .. last_slash + 1]
        else
            "";
        target.path = try std.fs.path.resolvePosix(arena, &.{ "/", base_dir, R.path });
    }
    target.query = R.query;
    return target;
}
const SliceReader = struct {
const Self = @This();
@ -284,6 +359,14 @@ fn isPathSeparator(c: u8) bool {
};
}
/// Returns whether `c` is kept unescaped by `escapePath`: '/', ':' and '@',
/// plus every byte accepted by `isUnreserved` or `isSubLimit`.
fn isPathChar(c: u8) bool {
    return switch (c) {
        '/', ':', '@' => true,
        else => isUnreserved(c) or isSubLimit(c),
    };
}
/// Returns whether `c` is kept unescaped by `escapeQuery`: '?' plus every
/// byte accepted by `isPathChar`.
fn isQueryChar(c: u8) bool {
    if (c == '?') return true;
    return isPathChar(c);
}
fn isQuerySeparator(c: u8) bool {
return switch (c) {
'#' => true,

View File

@ -89,7 +89,7 @@ pub const StreamInterface = struct {
};
pub fn InitError(comptime Stream: type) type {
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error {
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{
InsufficientEntropy,
DiskQuota,
LockViolation,

View File

@ -29,9 +29,10 @@ const ConnectionPool = std.TailQueue(Connection);
const ConnectionNode = ConnectionPool.Node;
/// Acquires an existing connection from the connection pool. This function is threadsafe.
pub fn acquire(client: *Client, node: *ConnectionNode) void {
client.connection_mutex.lock();
defer client.connection_mutex.unlock();
/// If the caller already holds the connection mutex, it should pass `true` for `held`.
pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
if (!held) client.connection_mutex.lock();
defer if (!held) client.connection_mutex.unlock();
client.connection_pool.remove(node);
client.connection_used.append(node);
@ -40,16 +41,17 @@ pub fn acquire(client: *Client, node: *ConnectionNode) void {
/// Tries to release a connection back to the connection pool. This function is threadsafe.
/// If the connection is marked as closing, it will be closed instead.
///
/// `node` must currently be linked into `client.connection_used`. After this
/// call it is either linked into `client.connection_pool`, or (when it was
/// marked as closing) closed and destroyed with `client.allocator`, in which
/// case the pointer must not be used again.
pub fn release(client: *Client, node: *ConnectionNode) void {
    // Take the mutex exactly once and hold it for the whole operation; a
    // second lock() on the same path would block forever if the mutex is not
    // recursive.
    client.connection_mutex.lock();
    defer client.connection_mutex.unlock();

    // Unlink from the in-use list exactly once, and before the node may be
    // destroyed, so the list never keeps a pointer to freed memory.
    client.connection_used.remove(node);

    if (node.data.closing) {
        node.data.close(client);

        return client.allocator.destroy(node);
    }

    client.connection_pool.append(node);
}
@ -83,7 +85,7 @@ pub const Connection = struct {
}
}
pub const ReadError = std.net.Stream.ReadError || error{
pub const ReadError = net.Stream.ReadError || error{
TlsConnectionTruncated,
TlsRecordOverflow,
TlsDecodeError,
@ -115,7 +117,7 @@ pub const Connection = struct {
}
}
pub const WriteError = std.net.Stream.WriteError || error{};
pub const WriteError = net.Stream.WriteError || error{};
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
@ -139,14 +141,21 @@ pub const Request = struct {
const read_buffer_size = 8192;
const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
uri: Uri,
client: *Client,
connection: *ConnectionNode,
redirects_left: u32,
response: Response,
/// These are stored in Request so that they are available when following
/// redirects.
headers: Headers,
redirects_left: u32,
handle_redirects: bool,
compression_init: bool,
/// Used as an allocator for resolving redirect locations.
arena: std.heap.ArenaAllocator,
/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning.
read_buffer: [read_buffer_size]u8 = undefined,
read_buffer_start: ReadBufferIndex = 0,
@ -661,6 +670,7 @@ pub const Request = struct {
pub const Headers = struct {
version: http.Version = .@"HTTP/1.1",
method: http.Method = .GET,
user_agent: []const u8 = "Zig (std.http)",
connection: http.Connection = .keep_alive,
transfer_encoding: RequestTransfer = .none,
@ -668,6 +678,7 @@ pub const Request = struct {
};
pub const Options = struct {
handle_redirects: bool = true,
max_redirects: u32 = 3,
header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
@ -703,10 +714,11 @@ pub const Request = struct {
req.client.release(req.connection);
}
req.arena.deinit();
req.* = undefined;
}
const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{
const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{
UnexpectedEndOfStream,
TooManyHttpRedirects,
HttpRedirectMissingLocation,
@ -723,9 +735,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
const amt = try req.readRawAdvanced(buffer[index..]);
const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
if (amt == 0 and zero_means_end) break;
if (amt == 0 and req.response.done) break;
index += amt;
}
@ -769,6 +779,8 @@ pub const Request = struct {
}
} else if (req.response.headers.content_length) |content_length| {
req.response.next_chunk_length = content_length;
if (content_length == 0) req.response.done = true;
} else {
req.response.done = true;
}
@ -779,7 +791,7 @@ pub const Request = struct {
return 0;
}
pub const WaitForCompleteHeadError = ReadRawError || error {
pub const WaitForCompleteHeadError = ReadRawError || error{
UnexpectedEndOfStream,
HttpHeadersExceededSizeLimit,
@ -810,27 +822,8 @@ pub const Request = struct {
/// This one can return 0 without meaning EOF.
fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
if (req.response.done) {
if (req.response.headers.status.class() == .redirect) {
if (req.redirects_left == 0) return error.TooManyHttpRedirects;
const location = req.response.headers.location orelse
return error.HttpRedirectMissingLocation;
const new_url = try std.Uri.parse(location);
const new_req = try req.client.request(new_url, req.headers, .{
.max_redirects = req.redirects_left - 1,
.header_strategy = if (req.response.header_bytes_owned) .{
.dynamic = req.response.max_header_bytes,
} else .{
.static = req.response.header_bytes.unusedCapacitySlice(),
},
});
req.deinit();
req.* = new_req;
} else {
return 0;
}
}
assert(req.response.state.isContent());
if (req.response.done) return 0;
// var in: []const u8 = undefined;
if (req.read_buffer_start == req.read_buffer_len) {
@ -851,7 +844,7 @@ pub const Request = struct {
const data_avail = req.response.next_chunk_length;
const out_avail = buffer.len;
if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
const can_read = @intCast(usize, @min(buf_avail, data_avail));
req.response.next_chunk_length -= can_read;
@ -859,7 +852,6 @@ pub const Request = struct {
req.client.release(req.connection);
req.connection = undefined;
req.response.done = true;
continue;
}
return 0; // skip over as much data as possible
@ -943,7 +935,7 @@ pub const Request = struct {
const data_avail = req.response.next_chunk_length;
const out_avail = buffer.len - out_index;
if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
const can_read = @intCast(usize, @min(buf_avail, data_avail));
req.response.next_chunk_length -= can_read;
@ -990,9 +982,41 @@ pub const Request = struct {
}
pub fn read(req: *Request, buffer: []u8) ReadError!usize {
if (!req.response.state.isContent()) try req.waitForCompleteHead();
while (true) {
if (!req.response.state.isContent()) try req.waitForCompleteHead();
if (req.response.compression == .none and req.response.state.isContent()) {
if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
assert(try req.readRaw(buffer) == 0);
if (req.redirects_left == 0) return error.TooManyHttpRedirects;
const location = req.response.headers.location orelse
return error.HttpRedirectMissingLocation;
const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
errdefer new_arena.deinit();
req.arena.deinit();
req.arena = new_arena;
const new_req = try req.client.request(resolved_url, req.headers, .{
.max_redirects = req.redirects_left - 1,
.header_strategy = if (req.response.header_bytes_owned) .{
.dynamic = req.response.max_header_bytes,
} else .{
.static = req.response.header_bytes.unusedCapacitySlice(),
},
});
req.deinit();
req.* = new_req;
} else {
break;
}
}
if (req.response.compression == .none) {
if (req.response.headers.transfer_compression) |compression| {
switch (compression) {
.compress => unreachable,
@ -1084,6 +1108,8 @@ pub const Request = struct {
};
pub fn deinit(client: *Client) void {
client.connection_mutex.lock();
var next = client.connection_pool.first;
while (next) |node| {
next = node.next;
@ -1106,7 +1132,7 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream);
pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
{ // Search through the connection pool for a potential connection.
@ -1120,7 +1146,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
const same_protocol = node.data.protocol == protocol;
if (same_host and same_port and same_protocol) {
client.acquire(node);
client.acquire(node, true);
return node;
}
@ -1168,6 +1194,7 @@ pub const RequestError = ConnectError || Connection.WriteError || error{
InvalidPadding,
MissingEndCertificateMarker,
Unseekable,
EndOfStream,
};
pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
@ -1196,27 +1223,52 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
}
var req: Request = .{
.uri = uri,
.client = client,
.headers = headers,
.connection = try client.connect(host, port, protocol),
.redirects_left = options.max_redirects,
.handle_redirects = options.handle_redirects,
.compression_init = false,
.response = switch (options.header_strategy) {
.dynamic => |max| Request.Response.initDynamic(max),
.static => |buf| Request.Response.initStatic(buf),
},
.arena = undefined,
};
req.arena = std.heap.ArenaAllocator.init(client.allocator);
{
var buffered = std.io.bufferedWriter(req.connection.data.writer());
const writer = buffered.writer();
const escaped_path = try Uri.escapePath(client.allocator, uri.path);
defer client.allocator.free(escaped_path);
const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
defer if (escaped_query) |q| client.allocator.free(q);
const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
defer if (escaped_fragment) |f| client.allocator.free(f);
try writer.writeAll(@tagName(headers.method));
try writer.writeByte(' ');
try writer.writeAll(uri.path);
try writer.writeAll(escaped_path);
if (escaped_query) |q| {
try writer.writeByte('?');
try writer.writeAll(q);
}
if (escaped_fragment) |f| {
try writer.writeByte('#');
try writer.writeAll(f);
}
try writer.writeByte(' ');
try writer.writeAll(@tagName(headers.version));
try writer.writeAll("\r\nHost: ");
try writer.writeAll(host);
try writer.writeAll("\r\nUser-Agent: ");
try writer.writeAll(headers.user_agent);
if (headers.connection == .close) {
try writer.writeAll("\r\nConnection: close");
} else {

View File

@ -741,9 +741,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
return Stream{ .handle = sockfd };
}
const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error {
const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error{
// TODO: break this up into error sets from the various underlying functions
TemporaryNameServerFailure,
NameServerFailure,
AddressFamilyNotSupported,
@ -760,7 +760,7 @@ const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError ||
Incomplete,
InvalidIpv4Mapping,
InvalidIPAddressFormat,
InterfaceNotFound,
FileSystem,
};