From 0a4130f364c2714b206257d0cf589103da823407 Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 6 Mar 2023 23:35:35 -0600 Subject: [PATCH] std.http: handle relative redirects --- lib/std/Uri.zig | 109 ++++++++++++++++++++++---- lib/std/crypto/tls/Client.zig | 2 +- lib/std/http/Client.zig | 140 +++++++++++++++++++++++----------- lib/std/net.zig | 6 +- 4 files changed, 196 insertions(+), 61 deletions(-) diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 015b6c34f6..eb6311a19b 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -16,15 +16,27 @@ fragment: ?[]const u8, /// Applies URI encoding and replaces all reserved characters with their respective %XX code. pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isUnreserved); +} + +pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isPathChar); +} + +pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isQueryChar); +} + +pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 { var outsize: usize = 0; for (input) |c| { - outsize += if (isUnreserved(c)) @as(usize, 1) else 3; + outsize += if (keepUnescaped(c)) @as(usize, 1) else 3; } var output = try allocator.alloc(u8, outsize); var outptr: usize = 0; for (input) |c| { - if (isUnreserved(c)) { + if (keepUnescaped(c)) { output[outptr] = c; outptr += 1; } else { @@ -94,13 +106,14 @@ pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{Out pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort }; -/// Parses the URI or returns an error. +/// Parses the URI or returns an error. This function is not compliant, but is required to parse +/// some forms of URIs in the wild. Such as HTTP Location headers. /// The return value will contain unescaped strings pointing into the /// original `text`. Each component that is provided, will be non-`null`. -pub fn parse(text: []const u8) ParseError!Uri { +pub fn parseWithoutScheme(text: []const u8) ParseError!Uri { var reader = SliceReader{ .slice = text }; var uri = Uri{ - .scheme = reader.readWhile(isSchemeChar), + .scheme = "", .user = null, .password = null, .host = null, @@ -110,14 +123,6 @@ pub fn parse(text: []const u8) ParseError!Uri { .fragment = null, }; - // after the scheme, a ':' must appear - if (reader.get()) |c| { - if (c != ':') - return error.UnexpectedCharacter; - } else { - return error.InvalidFormat; - } - if (reader.peekPrefix("//")) { // authority part std.debug.assert(reader.get().? == '/'); std.debug.assert(reader.get().? == '/'); @@ -179,6 +184,76 @@ pub fn parse(text: []const u8) ParseError!Uri { return uri; } +/// Parses the URI or returns an error. +/// The return value will contain unescaped strings pointing into the +/// original `text`. Each component that is provided, will be non-`null`. +pub fn parse(text: []const u8) ParseError!Uri { + var reader = SliceReader{ .slice = text }; + const scheme = reader.readWhile(isSchemeChar); + + // after the scheme, a ':' must appear + if (reader.get()) |c| { + if (c != ':') + return error.UnexpectedCharacter; + } else { + return error.InvalidFormat; + } + + var uri = try parseWithoutScheme(reader.readUntilEof()); + uri.scheme = scheme; + + return uri; +} + +/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. +/// arena owns any memory allocated by this function. +pub fn resolve(Base: Uri, R: Uri, strict: bool, arena: std.mem.Allocator) !Uri { + var T: Uri = undefined; + + if (R.scheme.len > 0 and !((!strict) and (std.mem.eql(u8, R.scheme, Base.scheme)))) { + T.scheme = R.scheme; + T.user = R.user; + T.host = R.host; + T.port = R.port; + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + T.query = R.query; + } else { + if (R.host) |host| { + T.user = R.user; + T.host = host; + T.port = R.port; + T.path = R.path; + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + T.query = R.query; + } else { + if (R.path.len == 0) { + T.path = Base.path; + if (R.query) |query| { + T.query = query; + } else { + T.query = Base.query; + } + } else { + if (R.path[0] == '/') { + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + } else { + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", Base.path, R.path }); + } + T.query = R.query; + } + + T.user = Base.user; + T.host = Base.host; + T.port = Base.port; + } + T.scheme = Base.scheme; + } + + T.fragment = R.fragment; + + return T; +} + const SliceReader = struct { const Self = @This(); @@ -284,6 +359,14 @@ fn isPathSeparator(c: u8) bool { }; } +fn isPathChar(c: u8) bool { + return isUnreserved(c) or isSubLimit(c) or c == '/' or c == ':' or c == '@'; +} + +fn isQueryChar(c: u8) bool { + return isPathChar(c) or c == '?'; +} + fn isQuerySeparator(c: u8) bool { return switch (c) { '#' => true, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 01bf957820..bc59459ff9 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -89,7 +89,7 @@ pub const StreamInterface = struct { }; pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error { + return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{ InsufficientEntropy, DiskQuota, LockViolation, diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index cac6571798..5b3a74d292 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -29,9 +29,10 @@ const ConnectionPool = std.TailQueue(Connection); const ConnectionNode = ConnectionPool.Node; /// Acquires an existing connection from the connection pool. This function is threadsafe. -pub fn acquire(client: *Client, node: *ConnectionNode) void { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); +/// If the caller already holds the connection mutex, it should pass `true` for `held`. +pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void { + if (!held) client.connection_mutex.lock(); + defer if (!held) client.connection_mutex.unlock(); client.connection_pool.remove(node); client.connection_used.append(node); @@ -40,16 +41,17 @@ pub fn acquire(client: *Client, node: *ConnectionNode) void { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. pub fn release(client: *Client, node: *ConnectionNode) void { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.remove(node); + if (node.data.closing) { node.data.close(client); return client.allocator.destroy(node); } - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); - - client.connection_used.remove(node); client.connection_pool.append(node); } @@ -83,7 +85,7 @@ pub const Connection = struct { } } - pub const ReadError = std.net.Stream.ReadError || error{ + pub const ReadError = net.Stream.ReadError || error{ TlsConnectionTruncated, TlsRecordOverflow, TlsDecodeError, @@ -115,7 +117,7 @@ pub const Connection = struct { } } - pub const WriteError = std.net.Stream.WriteError || error{}; + pub const WriteError = net.Stream.WriteError || error{}; pub const Writer = std.io.Writer(*Connection, WriteError, write); pub fn writer(conn: *Connection) Writer { @@ -139,14 +141,21 @@ pub const Request = struct { const read_buffer_size = 8192; const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); + uri: Uri, client: *Client, connection: *ConnectionNode, - redirects_left: u32, response: Response, /// These are stored in Request so that they are available when following /// redirects. headers: Headers, + redirects_left: u32, + handle_redirects: bool, + compression_init: bool, + + /// Used as a allocator for resolving redirects locations. + arena: std.heap.ArenaAllocator, + /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. read_buffer: [read_buffer_size]u8 = undefined, read_buffer_start: ReadBufferIndex = 0, @@ -661,6 +670,7 @@ pub const Request = struct { pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, + user_agent: []const u8 = "Zig (std.http)", connection: http.Connection = .keep_alive, transfer_encoding: RequestTransfer = .none, @@ -668,6 +678,7 @@ pub const Request = struct { }; pub const Options = struct { + handle_redirects: bool = true, max_redirects: u32 = 3, header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, @@ -703,10 +714,11 @@ pub const Request = struct { req.client.release(req.connection); } + req.arena.deinit(); req.* = undefined; } - const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{ + const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{ UnexpectedEndOfStream, TooManyHttpRedirects, HttpRedirectMissingLocation, @@ -723,9 +735,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { const amt = try req.readRawAdvanced(buffer[index..]); - const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; - - if (amt == 0 and zero_means_end) break; + if (amt == 0 and req.response.done) break; index += amt; } @@ -769,6 +779,8 @@ pub const Request = struct { } } else if (req.response.headers.content_length) |content_length| { req.response.next_chunk_length = content_length; + + if (content_length == 0) req.response.done = true; } else { req.response.done = true; } @@ -779,7 +791,7 @@ pub const Request = struct { return 0; } - pub const WaitForCompleteHeadError = ReadRawError || error { + pub const WaitForCompleteHeadError = ReadRawError || error{ UnexpectedEndOfStream, HttpHeadersExceededSizeLimit, @@ -810,27 +822,8 @@ pub const Request = struct { /// This one can return 0 without meaning EOF. fn readRawAdvanced(req: *Request, buffer: []u8) !usize { - if (req.response.done) { - if (req.response.headers.status.class() == .redirect) { - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = try std.Uri.parse(location); - const new_req = try req.client.request(new_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - } else { - return 0; - } - } + assert(req.response.state.isContent()); + if (req.response.done) return 0; // var in: []const u8 = undefined; if (req.read_buffer_start == req.read_buffer_len) { @@ -851,7 +844,7 @@ pub const Request = struct { const data_avail = req.response.next_chunk_length; const out_avail = buffer.len; - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { const can_read = @intCast(usize, @min(buf_avail, data_avail)); req.response.next_chunk_length -= can_read; @@ -859,7 +852,6 @@ pub const Request = struct { req.client.release(req.connection); req.connection = undefined; req.response.done = true; - continue; } return 0; // skip over as much data as possible @@ -943,7 +935,7 @@ pub const Request = struct { const data_avail = req.response.next_chunk_length; const out_avail = buffer.len - out_index; - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { const can_read = @intCast(usize, @min(buf_avail, data_avail)); req.response.next_chunk_length -= can_read; @@ -990,9 +982,41 @@ pub const Request = struct { } pub fn read(req: *Request, buffer: []u8) ReadError!usize { - if (!req.response.state.isContent()) try req.waitForCompleteHead(); + while (true) { + if (!req.response.state.isContent()) try req.waitForCompleteHead(); - if (req.response.compression == .none and req.response.state.isContent()) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { + assert(try req.readRaw(buffer) == 0); + + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); + + var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); + const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); + errdefer new_arena.deinit(); + + req.arena.deinit(); + req.arena = new_arena; + + const new_req = try req.client.request(resolved_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + } else { + break; + } + } + + if (req.response.compression == .none) { if (req.response.headers.transfer_compression) |compression| { switch (compression) { .compress => unreachable, @@ -1084,6 +1108,8 @@ pub const Request = struct { }; pub fn deinit(client: *Client) void { + client.connection_mutex.lock(); + var next = client.connection_pool.first; while (next) |node| { next = node.next; @@ -1106,7 +1132,7 @@ pub fn deinit(client: *Client) void { client.* = undefined; } -pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream); +pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { { // Search through the connection pool for a potential connection. @@ -1120,7 +1146,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio const same_protocol = node.data.protocol == protocol; if (same_host and same_port and same_protocol) { - client.acquire(node); + client.acquire(node, true); return node; } @@ -1168,6 +1194,7 @@ pub const RequestError = ConnectError || Connection.WriteError || error{ InvalidPadding, MissingEndCertificateMarker, Unseekable, + EndOfStream, }; pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { @@ -1196,27 +1223,52 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req } var req: Request = .{ + .uri = uri, .client = client, .headers = headers, .connection = try client.connect(host, port, protocol), .redirects_left = options.max_redirects, + .handle_redirects = options.handle_redirects, + .compression_init = false, .response = switch (options.header_strategy) { .dynamic => |max| Request.Response.initDynamic(max), .static => |buf| Request.Response.initStatic(buf), }, + .arena = undefined, }; + req.arena = std.heap.ArenaAllocator.init(client.allocator); + { var buffered = std.io.bufferedWriter(req.connection.data.writer()); const writer = buffered.writer(); + const escaped_path = try Uri.escapePath(client.allocator, uri.path); + defer client.allocator.free(escaped_path); + + const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null; + defer if (escaped_query) |q| client.allocator.free(q); + + const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null; + defer if (escaped_fragment) |f| client.allocator.free(f); + try writer.writeAll(@tagName(headers.method)); try writer.writeByte(' '); - try writer.writeAll(uri.path); + try writer.writeAll(escaped_path); + if (escaped_query) |q| { + try writer.writeByte('?'); + try writer.writeAll(q); + } + if (escaped_fragment) |f| { + try writer.writeByte('#'); + try writer.writeAll(f); + } try writer.writeByte(' '); try writer.writeAll(@tagName(headers.version)); try writer.writeAll("\r\nHost: "); try writer.writeAll(host); + try writer.writeAll("\r\nUser-Agent: "); + try writer.writeAll(headers.user_agent); if (headers.connection == .close) { try writer.writeAll("\r\nConnection: close"); } else { diff --git a/lib/std/net.zig b/lib/std/net.zig index cf112cbab9..7222433fd5 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -741,9 +741,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { return Stream{ .handle = sockfd }; } -const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error { +const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error{ // TODO: break this up into error sets from the various underlying functions - + TemporaryNameServerFailure, NameServerFailure, AddressFamilyNotSupported, @@ -760,7 +760,7 @@ const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || Incomplete, InvalidIpv4Mapping, InvalidIPAddressFormat, - + InterfaceNotFound, FileSystem, };