From afb26f4e6b39431001eff75cc8ce19144cb5301a Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 2 Mar 2023 12:45:34 -0600 Subject: [PATCH 1/6] std.http: add connection pooling and make keep-alive requests by default --- lib/std/http/Client.zig | 211 ++++++++++++++++++++++++++++++---------- 1 file changed, 161 insertions(+), 50 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 7f62b2d597..d4d8f85ad1 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,11 +21,27 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, +connection_pool: std.TailQueue(Connection) = .{}, + +const ConnectionPool = std.TailQueue(Connection); +const ConnectionNode = ConnectionPool.Node; + +pub fn release(client: *Client, node: *ConnectionNode) void { + if (node.data.unusable) return node.data.close(client); + + client.connection_pool.append(node); +} + pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: std.crypto.tls.Client, + tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. protocol: Protocol, + host: []u8, + port: u16, + + // This connection has been part of a non keepalive request and cannot be added to the pool. + unusable: bool = false, pub const Protocol = enum { plain, tls }; @@ -56,6 +72,17 @@ pub const Connection = struct { .tls => return conn.tls_client.write(conn.stream, buffer), } } + + pub fn close(conn: *Connection, client: *const Client) void { + if (conn.protocol == .tls) { + // try to cleanly close the TLS connection, for any server that cares. 
+ _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + } + + conn.stream.close(); + + client.allocator.free(conn.host); + } }; /// TODO: emit error.UnexpectedEndOfStream or something like that when the read @@ -63,7 +90,7 @@ pub const Connection = struct { /// close_notify protection on underlying TLS streams. pub const Request = struct { client: *Client, - connection: Connection, + connection: *ConnectionNode, redirects_left: u32, response: Response, /// These are stored in Request so that they are available when following @@ -79,6 +106,7 @@ pub const Request = struct { header_bytes: std.ArrayListUnmanaged(u8), max_header_bytes: usize, next_chunk_length: u64, + done: bool, pub const Headers = struct { status: http.Status, @@ -86,6 +114,7 @@ pub const Request = struct { location: ?[]const u8 = null, content_length: ?u64 = null, transfer_encoding: ?http.TransferEncoding = null, + connection_close: bool = true, pub fn parse(bytes: []const u8) !Response.Headers { var it = mem.split(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); @@ -126,6 +155,14 @@ pub const Request = struct { if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { + headers.connection_close = false; + } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { + headers.connection_close = true; + } else { + return error.HttpConnectionHeaderUnsupported; + } } } @@ -185,10 +222,10 @@ pub const Request = struct { chunk_r, chunk_data, - pub fn zeroMeansEnd(state: State) bool { - return switch (state) { - .finished, .chunk_data => true, - else => false, + pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, }; } }; @@ -201,6 +238,7 @@ pub const Request = struct { .max_header_bytes = max, .header_bytes_owned = true, .next_chunk_length = undefined, + .done = false, }; } @@ -212,6 +250,7 @@ pub const Request = struct { .max_header_bytes = buf.len, .header_bytes_owned = false, .next_chunk_length = undefined, + .done = false, }; } @@ -501,6 +540,7 @@ pub const Request = struct { pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, + connection_close: bool = false, }; pub const Options = struct { @@ -545,6 +585,7 @@ pub const Request = struct { HttpHeadersExceededSizeLimit, HttpRedirectMissingLocation, HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, HttpContentLengthUnknown, TooManyHttpRedirects, ShortHttpStatusLine, @@ -669,8 +710,9 @@ pub const Request = struct { assert(len <= buffer.len); var index: usize = 0; while (index < len) { - const zero_means_end = req.response.state.zeroMeansEnd(); const 
amt = try readAdvanced(req, buffer[index..]); + const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; + if (amt == 0 and zero_means_end) break; index += amt; } @@ -680,7 +722,29 @@ pub const Request = struct { /// This one can return 0 without meaning EOF. /// TODO change to readvAdvanced pub fn readAdvanced(req: *Request, buffer: []u8) !usize { - var in = buffer[0..try req.connection.read(buffer)]; + if (req.response.done) { + if (req.response.headers.status.class() == .redirect) { + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = try std.Uri.parse(location); + const new_req = try req.client.request(new_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + } else { + return 0; + } + } + + var in = buffer[0..try req.connection.data.read(buffer)]; var out_index: usize = 0; while (true) { switch (req.response.state) { @@ -698,24 +762,10 @@ pub const Request = struct { if (req.response.state == .finished) { req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - if (req.response.headers.status.class() == .redirect) { - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = try std.Uri.parse(location); - const new_req = try req.client.request(new_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); 
- req.* = new_req; - assert(out_index == 0); - in = buffer[0..try req.connection.read(buffer)]; - continue; + if (req.response.headers.connection_close == true) { + req.connection.data.unusable = true; + } else { + req.connection.data.unusable = false; } if (req.response.headers.transfer_encoding) |transfer_encoding| { @@ -742,11 +792,29 @@ pub const Request = struct { return 0; }, .finished => { + const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); + req.response.next_chunk_length -= sub_amt; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + + req.response.done = true; + assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary. + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; + + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + return out_index + sub_amt; + } + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; + if (in.ptr == buffer.ptr) { - return in.len; + return sub_amt; } else { - mem.copy(u8, buffer[out_index..], in); - return out_index + in.len; + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + return out_index + sub_amt; } }, .chunk_size_prefix_r => switch (in.len) { @@ -793,7 +861,10 @@ pub const Request = struct { .invalid => return error.HttpHeadersInvalid, .chunk_data => { if (req.response.next_chunk_length == 0) { - req.response.state = .start; + req.response.done = true; + req.client.release(req.connection); + req.connection = undefined; + return out_index; } in = in[i..]; @@ -807,20 +878,27 @@ pub const Request = struct { // TODO https://github.com/ziglang/zig/issues/14039 const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); req.response.next_chunk_length -= sub_amt; - if (req.response.next_chunk_length > 0) { - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, 
buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - return out_index; - } + + if (req.response.next_chunk_length == 0) { + req.response.state = .chunk_size_prefix_r; + in = in[sub_amt..]; + + if (req.response.headers.status.class() == .redirect) continue; + + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + out_index += sub_amt; + continue; + } + + if (req.response.headers.status.class() == .redirect) return 0; + + if (in.ptr == buffer.ptr) { + return sub_amt; + } else { + mem.copy(u8, buffer[out_index..], in[0..sub_amt]); + out_index += sub_amt; + return out_index; } - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - req.response.state = .chunk_size_prefix_r; - in = in[sub_amt..]; - continue; }, } } @@ -844,24 +922,52 @@ pub const Request = struct { }; pub fn deinit(client: *Client) void { + var next = client.connection_pool.first; + while (next) |node| { + next = node.next; + + node.data.close(client); + + client.allocator.destroy(node); + } + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection { - var conn: Connection = .{ +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode { + var potential = client.connection_pool.last; + while (potential) |node| { + const same_host = mem.eql(u8, node.data.host, host); + const same_port = node.data.port == port; + const same_protocol = node.data.protocol == protocol; + + if (same_host and same_port and same_protocol) { + client.connection_pool.remove(node); + return node; + } + + potential = node.prev; + } + + const conn = try client.allocator.create(ConnectionNode); + errdefer client.allocator.destroy(conn); + + conn.* = .{ .data = .{ .stream = try net.tcpConnectToHost(client.allocator, host, port), .tls_client = undefined, .protocol = protocol, - }; + .host = try client.allocator.dupe(u8, host), + .port = 
port, + } }; switch (protocol) { .plain => {}, .tls => { - conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host); + conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + conn.data.tls_client.allow_truncation_attacks = true; }, } @@ -908,10 +1014,15 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req try h.appendSlice(@tagName(headers.version)); try h.appendSlice("\r\nHost: "); try h.appendSlice(host); - try h.appendSlice("\r\nConnection: close\r\n\r\n"); + if (headers.connection_close) { + try h.appendSlice("\r\nConnection: close"); + } else { + try h.appendSlice("\r\nConnection: keep-alive"); + } + try h.appendSlice("\r\n\r\n"); const header_bytes = h.slice(); - try req.connection.writeAll(header_bytes); + try req.connection.data.writeAll(header_bytes); } return req; From 8d86194b6e31788263d2cbdd03e2a8cde4134c37 Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 6 Mar 2023 20:11:56 -0600 Subject: [PATCH 2/6] add error sets to tcpConnect* and tls.Client.init --- lib/std/crypto/tls/Client.zig | 50 ++++++++++++++++++++++++++++++++++- lib/std/net.zig | 34 +++++++++++++++++++++--- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 627ad7ea59..01bf957820 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -88,11 +88,59 @@ pub const StreamInterface = struct { } }; +pub fn InitError(comptime Stream: type) type { + return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error { + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + TlsAlert, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + 
TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + }; +} + /// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which /// must conform to `StreamInterface`. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client { +pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { const host_len = @intCast(u16, host.len); var random_buffer: [128]u8 = undefined; diff --git a/lib/std/net.zig b/lib/std/net.zig index 50a0f8b9d7..cf112cbab9 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -702,8 +702,10 @@ pub const AddressList = struct { } }; +pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; + /// All memory allocated with `allocator` will be freed before this function returns. 
-pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) !Stream { +pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { const list = try getAddressList(allocator, name, port); defer list.deinit(); @@ -720,7 +722,9 @@ pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) ! return std.os.ConnectError.ConnectionRefused; } -pub fn tcpConnectToAddress(address: Address) !Stream { +pub const TcpConnectToAddressError = std.os.SocketError || std.os.ConnectError; + +pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { const nonblock = if (std.io.is_async) os.SOCK.NONBLOCK else 0; const sock_flags = os.SOCK.STREAM | nonblock | (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); @@ -737,8 +741,32 @@ pub fn tcpConnectToAddress(address: Address) !Stream { return Stream{ .handle = sockfd }; } +const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error { + // TODO: break this up into error sets from the various underlying functions + + TemporaryNameServerFailure, + NameServerFailure, + AddressFamilyNotSupported, + UnknownHostName, + ServiceUnavailable, + Unexpected, + + HostLacksNetworkAddresses, + + InvalidCharacter, + InvalidEnd, + NonCanonical, + Overflow, + Incomplete, + InvalidIpv4Mapping, + InvalidIPAddressFormat, + + InterfaceNotFound, + FileSystem, +}; + /// Call `AddressList.deinit` on the result. 
-pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) !*AddressList { +pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) GetAddressListError!*AddressList { const result = blk: { var arena = std.heap.ArenaAllocator.init(allocator); errdefer arena.deinit(); From fd2f906d1ede2b65ba21eec59137b2d4b676eedc Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 6 Mar 2023 20:13:15 -0600 Subject: [PATCH 3/6] std.http: handle compressed payloads --- lib/std/http.zig | 10 + lib/std/http/Client.zig | 769 ++++++++++++++++++++++++++-------------- 2 files changed, 504 insertions(+), 275 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index 7c2a2da605..d4cc259f19 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -253,6 +253,16 @@ pub const TransferEncoding = enum { gzip, }; +pub const Connection = enum { + keep_alive, + close, +}; + +pub const CustomHeader = struct { + name: []const u8, + value: []const u8, +}; + const std = @import("std.zig"); test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d4d8f85ad1..cac6571798 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,27 +21,51 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, -connection_pool: std.TailQueue(Connection) = .{}, +connection_mutex: std.Thread.Mutex = .{}, +connection_pool: ConnectionPool = .{}, +connection_used: ConnectionPool = .{}, const ConnectionPool = std.TailQueue(Connection); const ConnectionNode = ConnectionPool.Node; -pub fn release(client: *Client, node: *ConnectionNode) void { - if (node.data.unusable) return node.data.close(client); +/// Acquires an existing connection from the connection pool. This function is threadsafe. 
+pub fn acquire(client: *Client, node: *ConnectionNode) void { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + client.connection_pool.remove(node); + client.connection_used.append(node); +} + +/// Tries to release a connection back to the connection pool. This function is threadsafe. +/// If the connection is marked as closing, it will be closed instead. +pub fn release(client: *Client, node: *ConnectionNode) void { + if (node.data.closing) { + node.data.close(client); + + return client.allocator.destroy(node); + } + + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.remove(node); client.connection_pool.append(node); } +const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); +const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); + pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. + tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. protocol: Protocol, host: []u8, port: u16, // This connection has been part of a non keepalive request and cannot be added to the pool. 
- unusable: bool = false, + closing: bool = false, pub const Protocol = enum { plain, tls }; @@ -59,6 +83,24 @@ pub const Connection = struct { } } + pub const ReadError = std.net.Stream.ReadError || error{ + TlsConnectionTruncated, + TlsRecordOverflow, + TlsDecodeError, + TlsAlert, + TlsBadRecordMac, + Overflow, + TlsBadLength, + TlsIllegalParameter, + TlsUnexpectedMessage, + }; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + pub fn writeAll(conn: *Connection, buffer: []const u8) !void { switch (conn.protocol) { .plain => return conn.stream.writeAll(buffer), @@ -73,10 +115,18 @@ pub const Connection = struct { } } + pub const WriteError = std.net.Stream.WriteError || error{}; + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + pub fn close(conn: *Connection, client: *const Client) void { if (conn.protocol == .tls) { // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + client.allocator.destroy(conn.tls_client); } conn.stream.close(); @@ -85,10 +135,10 @@ pub const Connection = struct { } }; -/// TODO: emit error.UnexpectedEndOfStream or something like that when the read -/// data does not match the content length. This is necessary since HTTPS disables -/// close_notify protection on underlying TLS streams. pub const Request = struct { + const read_buffer_size = 8192; + const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); + client: *Client, connection: *ConnectionNode, redirects_left: u32, @@ -97,6 +147,11 @@ pub const Request = struct { /// redirects. headers: Headers, + /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. 
This can probably be removed with careful planning. + read_buffer: [read_buffer_size]u8 = undefined, + read_buffer_start: ReadBufferIndex = 0, + read_buffer_len: ReadBufferIndex = 0, + pub const Response = struct { headers: Response.Headers, state: State, @@ -106,15 +161,24 @@ pub const Request = struct { header_bytes: std.ArrayListUnmanaged(u8), max_header_bytes: usize, next_chunk_length: u64, - done: bool, + done: bool = false, + + compression: union(enum) { + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + none: void, + } = .none, pub const Headers = struct { status: http.Status, version: http.Version, location: ?[]const u8 = null, content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - connection_close: bool = true, + transfer_encoding: ?http.TransferEncoding = null, // This should only ever be chunked, compression is handled separately. + transfer_compression: ?http.TransferEncoding = null, + connection: http.Connection = .close, + + number_of_headers: usize = 0, pub fn parse(bytes: []const u8) !Response.Headers { var it = mem.split(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); @@ -137,6 +201,8 @@ pub const Request = struct { }; while (it.next()) |line| { + headers.number_of_headers += 1; + if (line.len == 0) return error.HttpHeadersInvalid; switch (line[0]) { ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, @@ -152,14 +218,65 @@ pub const Request = struct { if (headers.content_length != null) return error.HttpHeadersInvalid; headers.content_length = try std.fmt.parseInt(u64, header_value, 10); } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse + if (headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid; + + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = std.mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, first, " "), + ) orelse + return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => headers.transfer_encoding = .chunked, + .compress => headers.transfer_compression = .compress, + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } + } + + if (iter.next()) |second| { + if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, second, " "), + ) orelse + return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => return error.HttpHeadersInvalid, // chunked must come last + .compress => return error.HttpTransferEncodingUnsupported, // compress not supported + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } + } + + if (iter.next()) |_| return 
error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, header_value, " "), + ) orelse return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => return error.HttpHeadersInvalid, // not transfer encoding + .compress => return error.HttpTransferEncodingUnsupported, // compress not supported + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection_close = false; + headers.connection = .keep_alive; } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection_close = true; + headers.connection = .close; } else { return error.HttpConnectionHeaderUnsupported; } @@ -238,7 +355,6 @@ pub const Request = struct { .max_header_bytes = max, .header_bytes_owned = true, .next_chunk_length = undefined, - .done = false, }; } @@ -250,7 +366,6 @@ pub const Request = struct { .max_header_bytes = buf.len, .header_bytes_owned = false, .next_chunk_length = undefined, - .done = false, }; } @@ -537,10 +652,19 @@ pub const Request = struct { } }; + pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, + }; + pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, - connection_close: bool = false, + connection: http.Connection = .keep_alive, + transfer_encoding: RequestTransfer = .none, + + custom: []const http.CustomHeader = &[_]http.CustomHeader{}, }; pub const Options = struct { @@ -561,167 +685,131 @@ pub const Request = struct { }; }; - /// May be skipped if header strategy is buffer. + /// Frees all resources associated with the request. 
pub fn deinit(req: *Request) void { + switch (req.response.compression) { + .none => {}, + .deflate => |*deflate| deflate.deinit(), + .gzip => |*gzip| gzip.deinit(), + } + if (req.response.header_bytes_owned) { req.response.header_bytes.deinit(req.client.allocator); } + + if (!req.response.done) { + // If the response wasn't fully read, then we need to close the connection. + req.connection.data.closing = true; + req.client.release(req.connection); + } + req.* = undefined; } - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - pub fn readAll(req: *Request, buffer: []u8) !usize { - return readAtLeast(req, buffer, buffer.len); - } - - pub const ReadError = net.Stream.ReadError || error{ - // From HTTP protocol - HttpHeadersInvalid, - HttpHeadersExceededSizeLimit, - HttpRedirectMissingLocation, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - HttpContentLengthUnknown, + const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{ + UnexpectedEndOfStream, TooManyHttpRedirects, - ShortHttpStatusLine, - BadHttpVersion, - HttpHeaderContinuationsUnsupported, - UnsupportedUrlScheme, - UriMissingHost, - UnknownHostName, - - // Network problems - NetworkUnreachable, - HostLacksNetworkAddresses, - TemporaryNameServerFailure, - NameServerFailure, - ProtocolFamilyNotAvailable, - ProtocolNotSupported, - - // System resource problems - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - OutOfMemory, - - // TLS problems - InsufficientEntropy, - TlsConnectionTruncated, - TlsRecordOverflow, - TlsDecodeError, - TlsAlert, - TlsBadRecordMac, - TlsBadLength, - TlsIllegalParameter, - TlsUnexpectedMessage, - TlsDecryptFailure, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - 
CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - TlsDecryptError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - CertificateAuthorityBundleTooBig, - - // TODO: convert to higher level errors - InvalidFormat, - InvalidPort, - UnexpectedCharacter, - Overflow, - InvalidCharacter, - AddressFamilyNotSupported, - AddressInUse, - AddressNotAvailable, - ConnectionPending, - ConnectionRefused, - FileNotFound, - PermissionDenied, - ServiceUnavailable, - SocketTypeNotSupported, - FileTooBig, - LockViolation, - NoSpaceLeft, - NotOpenForWriting, - InvalidEncoding, - IdentityElement, - NonCanonical, - SignatureVerificationFailed, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - DiskQuota, - InvalidEnd, - Incomplete, - InvalidIpv4Mapping, - InvalidIPAddressFormat, - BadPathName, - DeviceBusy, - FileBusy, - FileLocksNotSupported, - InvalidHandle, - InvalidUtf8, - NameTooLong, - NoDevice, - PathAlreadyExists, - PipeBusy, - SharingViolation, - SymLinkLoop, - FileSystem, - InterfaceNotFound, - AlreadyBound, - FileDescriptorNotASocket, - NetworkSubsystemFailed, - NotDir, - ReadOnlyFileSystem, - Unseekable, - MissingEndCertificateMarker, - InvalidPadding, - EndOfStream, - InvalidArgument, + HttpRedirectMissingLocation, + HttpHeadersInvalid, }; - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - return readAtLeast(req, buffer, 1); - } + const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw); + + /// Read from the underlying stream, without decompressing or parsing the headers. 
Must be called + /// after waitForCompleteHead() has returned successfully. + pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize { + assert(req.response.state.isContent()); - pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { - assert(len <= buffer.len); var index: usize = 0; - while (index < len) { - const amt = try readAdvanced(req, buffer[index..]); + while (index == 0) { + const amt = try req.readRawAdvanced(buffer[index..]); const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; if (amt == 0 and zero_means_end) break; index += amt; } + return index; } + fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { + switch (req.response.state) { + .invalid => unreachable, + .start, .seen_r, .seen_rn, .seen_rnr => {}, + else => return 0, // No more headers to read. + } + + const i = req.response.findHeadersEnd(buffer[0..]); + if (req.response.state == .invalid) return error.HttpHeadersInvalid; + + const headers_data = buffer[0..i]; + if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { + return error.HttpHeadersExceededSizeLimit; + } + try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); + + if (req.response.state == .finished) { + req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); + + if (req.response.headers.connection == .keep_alive) { + req.connection.data.closing = false; + } else { + req.connection.data.closing = true; + } + + if (req.response.headers.transfer_encoding) |transfer_encoding| { + switch (transfer_encoding) { + .chunked => { + req.response.next_chunk_length = 0; + req.response.state = .chunk_size; + }, + .compress => unreachable, + .deflate => unreachable, + .gzip => unreachable, + } + } else if (req.response.headers.content_length) |content_length| { + req.response.next_chunk_length = content_length; + } else { + req.response.done = true; + } + + return i; + } + + return 0; 
+ } + + pub const WaitForCompleteHeadError = ReadRawError || error { + UnexpectedEndOfStream, + + HttpHeadersExceededSizeLimit, + ShortHttpStatusLine, + BadHttpVersion, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + }; + + /// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent. + pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void { + if (req.response.state.isContent()) return; + + while (true) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]); + + if (amt != 0) { + req.read_buffer_start = @intCast(ReadBufferIndex, amt); + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + return; + } else if (nread == 0) { + return error.UnexpectedEndOfStream; + } + } + } + /// This one can return 0 without meaning EOF. - /// TODO change to readvAdvanced - pub fn readAdvanced(req: *Request, buffer: []u8) !usize { + fn readRawAdvanced(req: *Request, buffer: []u8) !usize { if (req.response.done) { if (req.response.headers.status.class() == .redirect) { if (req.redirects_left == 0) return error.TooManyHttpRedirects; @@ -744,82 +832,56 @@ pub const Request = struct { } } - var in = buffer[0..try req.connection.data.read(buffer)]; + // var in: []const u8 = undefined; + if (req.read_buffer_start == req.read_buffer_len) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + if (nread == 0) return error.UnexpectedEndOfStream; + + req.read_buffer_start = 0; + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + } + var out_index: usize = 0; while (true) { switch (req.response.state) { - .invalid => unreachable, - .start, .seen_r, .seen_rn, .seen_rnr => { - const i = req.response.findHeadersEnd(in); - if (req.response.state == .invalid) return error.HttpHeadersInvalid; - - const headers_data = in[0..i]; - if 
(req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } - try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); - - if (req.response.state == .finished) { - req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - - if (req.response.headers.connection_close == true) { - req.connection.data.unusable = true; - } else { - req.connection.data.unusable = false; - } - - if (req.response.headers.transfer_encoding) |transfer_encoding| { - switch (transfer_encoding) { - .chunked => { - req.response.next_chunk_length = 0; - req.response.state = .chunk_size; - }, - .compress => return error.HttpTransferEncodingUnsupported, - .deflate => return error.HttpTransferEncodingUnsupported, - .gzip => return error.HttpTransferEncodingUnsupported, - } - } else if (req.response.headers.content_length) |content_length| { - req.response.next_chunk_length = content_length; - } else { - return error.HttpContentLengthUnknown; - } - - in = in[i..]; - continue; - } - - assert(out_index == 0); - return 0; - }, + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable, .finished => { - const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); - req.response.next_chunk_length -= sub_amt; + // TODO https://github.com/ziglang/zig/issues/14039 + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len; + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + continue; + } + + return 0; // skip over as much data as possible + } + + const can_read = 
@intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start += @intCast(ReadBufferIndex, can_read); if (req.response.next_chunk_length == 0) { req.client.release(req.connection); req.connection = undefined; - req.response.done = true; - assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary. - - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; - - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - return out_index + sub_amt; } - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; - - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - return out_index + sub_amt; - } + return can_read; }, - .chunk_size_prefix_r => switch (in.len) { + .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) { 0 => return out_index, - 1 => switch (in[0]) { + 1 => switch (req.read_buffer[req.read_buffer_start]) { '\r' => { req.response.state = .chunk_size_prefix_n; return out_index; @@ -829,9 +891,9 @@ pub const Request = struct { return error.HttpHeadersInvalid; }, }, - else => switch (int16(in[0..2])) { + else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) { int16("\r\n") => { - in = in[2..]; + req.read_buffer_start += 2; req.response.state = .chunk_size; continue; }, @@ -841,11 +903,11 @@ pub const Request = struct { }, }, }, - .chunk_size_prefix_n => switch (in.len) { + .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) { 0 => return out_index, - else => switch (in[0]) { + else => switch (req.read_buffer[req.read_buffer_start]) { '\n' => { - in = in[1..]; + req.read_buffer_start += 1; req.response.state = .chunk_size; continue; }, @@ -856,7 +918,7 @@ pub const Request = struct { 
}, }, .chunk_size, .chunk_r => { - const i = req.response.findChunkedLen(in); + const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]); switch (req.response.state) { .invalid => return error.HttpHeadersInvalid, .chunk_data => { @@ -867,7 +929,8 @@ pub const Request = struct { return out_index; } - in = in[i..]; + + req.read_buffer_start += @intCast(ReadBufferIndex, i); continue; }, .chunk_size => return out_index, @@ -876,34 +939,129 @@ pub const Request = struct { }, .chunk_data => { // TODO https://github.com/ziglang/zig/issues/14039 - const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); - req.response.next_chunk_length -= sub_amt; + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + continue; + } + + return 0; // skip over as much data as possible + } + + const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start += @intCast(ReadBufferIndex, can_read); + out_index += can_read; if (req.response.next_chunk_length == 0) { req.response.state = .chunk_size_prefix_r; - in = in[sub_amt..]; - if (req.response.headers.status.class() == .redirect) continue; - - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; continue; } - if (req.response.headers.status.class() == .redirect) return 0; - - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { 
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - return out_index; - } + return out_index; }, } } } + pub const ReadError = DeflateDecompressor.Error || GzipDecompressor.Error || WaitForCompleteHeadError || error{ + BadHeader, + InvalidCompression, + StreamTooLong, + InvalidWindowSize, + }; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + pub fn read(req: *Request, buffer: []u8) ReadError!usize { + if (!req.response.state.isContent()) try req.waitForCompleteHead(); + + if (req.response.compression == .none and req.response.state.isContent()) { + if (req.response.headers.transfer_compression) |compression| { + switch (compression) { + .compress => unreachable, + .deflate => req.response.compression = .{ + .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }), + }, + .gzip => req.response.compression = .{ + .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }), + }, + .chunked => unreachable, + } + } + } + + return switch (req.response.compression) { + .deflate => |*deflate| try deflate.read(buffer), + .gzip => |*gzip| try gzip.read(buffer), + else => try req.readRaw(buffer), + }; + } + + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = Connection.WriteError || error{MessageTooLong}; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. 
+ pub fn write(req: *Request, bytes: []const u8) !usize { + switch (req.headers.transfer_encoding) { + .chunked => { + try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.data.writeAll(bytes); + try req.connection.data.writeAll("\r\n"); + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try req.connection.data.write(bytes); + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Finish the body of a request. This notifies the server that you have no more data to send. + pub fn finish(req: *Request) !void { + switch (req.headers.transfer_encoding) { + .chunked => try req.connection.data.writeAll("0\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } + } + inline fn int16(array: *const [2]u8) u16 { return @bitCast(u16, array.*); } @@ -917,6 +1075,10 @@ pub const Request = struct { } test { + const builtin = @import("builtin"); + + if (builtin.os.tag == .wasi) return error.SkipZigTest; + _ = Response; } }; @@ -931,23 +1093,39 @@ pub fn deinit(client: *Client) void { client.allocator.destroy(node); } + next = client.connection_used.first; + while (next) |node| { + next = node.next; + + node.data.close(client); + + client.allocator.destroy(node); + } + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode { - var potential = client.connection_pool.last; - while (potential) |node| { - const same_host = mem.eql(u8, node.data.host, host); - const same_port = node.data.port == port; - const same_protocol = node.data.protocol == protocol; +pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream); - if (same_host and same_port and same_protocol) { - client.connection_pool.remove(node); - 
return node; +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { + { // Search through the connection pool for a potential connection. + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + var potential = client.connection_pool.last; + while (potential) |node| { + const same_host = mem.eql(u8, node.data.host, host); + const same_port = node.data.port == port; + const same_protocol = node.data.protocol == protocol; + + if (same_host and same_port and same_protocol) { + client.acquire(node); + return node; + } + + potential = node.prev; } - - potential = node.prev; } const conn = try client.allocator.create(ConnectionNode); @@ -964,17 +1142,35 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio switch (protocol) { .plain => {}, .tls => { - conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. 
conn.data.tls_client.allow_truncation_attacks = true; }, } + { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.append(conn); + } + return conn; } -pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) !Request { +pub const RequestError = ConnectError || Connection.WriteError || error{ + UnsupportedUrlScheme, + UriMissingHost, + + CertificateAuthorityBundleTooBig, + InvalidPadding, + MissingEndCertificateMarker, + Unseekable, +}; + +pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http")) .plain else if (mem.eql(u8, uri.scheme, "https")) @@ -990,8 +1186,13 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; if (client.next_https_rescan_certs and protocol == .tls) { - try client.ca_bundle.rescan(client.allocator); - client.next_https_rescan_certs = false; + client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. 
+ defer client.connection_mutex.unlock(); + + if (client.next_https_rescan_certs) { + try client.ca_bundle.rescan(client.allocator); + client.next_https_rescan_certs = false; + } } var req: Request = .{ @@ -1006,23 +1207,39 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req }; { - var h = try std.BoundedArray(u8, 1000).init(0); - try h.appendSlice(@tagName(headers.method)); - try h.appendSlice(" "); - try h.appendSlice(uri.path); - try h.appendSlice(" "); - try h.appendSlice(@tagName(headers.version)); - try h.appendSlice("\r\nHost: "); - try h.appendSlice(host); - if (headers.connection_close) { - try h.appendSlice("\r\nConnection: close"); - } else { - try h.appendSlice("\r\nConnection: keep-alive"); - } - try h.appendSlice("\r\n\r\n"); + var buffered = std.io.bufferedWriter(req.connection.data.writer()); + const writer = buffered.writer(); - const header_bytes = h.slice(); - try req.connection.data.writeAll(header_bytes); + try writer.writeAll(@tagName(headers.method)); + try writer.writeByte(' '); + try writer.writeAll(uri.path); + try writer.writeByte(' '); + try writer.writeAll(@tagName(headers.version)); + try writer.writeAll("\r\nHost: "); + try writer.writeAll(host); + if (headers.connection == .close) { + try writer.writeAll("\r\nConnection: close"); + } else { + try writer.writeAll("\r\nConnection: keep-alive"); + } + try writer.writeAll("\r\nAccept-Encoding: gzip, deflate"); + + switch (headers.transfer_encoding) { + .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"), + .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}), + .none => {}, + } + + for (headers.custom) |header| { + try writer.writeAll("\r\n"); + try writer.writeAll(header.name); + try writer.writeAll(": "); + try writer.writeAll(header.value); + } + + try writer.writeAll("\r\n\r\n"); + + try buffered.flush(); } return req; @@ -1036,5 +1253,7 @@ test { return error.SkipZigTest; } + if 
(builtin.os.tag == .wasi) return error.SkipZigTest; + _ = Request; } From 0a4130f364c2714b206257d0cf589103da823407 Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 6 Mar 2023 23:35:35 -0600 Subject: [PATCH 4/6] std.http: handle relative redirects --- lib/std/Uri.zig | 109 ++++++++++++++++++++++---- lib/std/crypto/tls/Client.zig | 2 +- lib/std/http/Client.zig | 140 +++++++++++++++++++++++----------- lib/std/net.zig | 6 +- 4 files changed, 196 insertions(+), 61 deletions(-) diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 015b6c34f6..eb6311a19b 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -16,15 +16,27 @@ fragment: ?[]const u8, /// Applies URI encoding and replaces all reserved characters with their respective %XX code. pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isUnreserved); +} + +pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isPathChar); +} + +pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 { + return escapeStringWithFn(allocator, input, isQueryChar); +} + +pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 { var outsize: usize = 0; for (input) |c| { - outsize += if (isUnreserved(c)) @as(usize, 1) else 3; + outsize += if (keepUnescaped(c)) @as(usize, 1) else 3; } var output = try allocator.alloc(u8, outsize); var outptr: usize = 0; for (input) |c| { - if (isUnreserved(c)) { + if (keepUnescaped(c)) { output[outptr] = c; outptr += 1; } else { @@ -94,13 +106,14 @@ pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{Out pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort }; -/// Parses the URI or returns an error. +/// Parses the URI or returns an error. 
This function is not compliant, but is required to parse +/// some forms of URIs in the wild. Such as HTTP Location headers. /// The return value will contain unescaped strings pointing into the /// original `text`. Each component that is provided, will be non-`null`. -pub fn parse(text: []const u8) ParseError!Uri { +pub fn parseWithoutScheme(text: []const u8) ParseError!Uri { var reader = SliceReader{ .slice = text }; var uri = Uri{ - .scheme = reader.readWhile(isSchemeChar), + .scheme = "", .user = null, .password = null, .host = null, @@ -110,14 +123,6 @@ pub fn parse(text: []const u8) ParseError!Uri { .fragment = null, }; - // after the scheme, a ':' must appear - if (reader.get()) |c| { - if (c != ':') - return error.UnexpectedCharacter; - } else { - return error.InvalidFormat; - } - if (reader.peekPrefix("//")) { // authority part std.debug.assert(reader.get().? == '/'); std.debug.assert(reader.get().? == '/'); @@ -179,6 +184,76 @@ pub fn parse(text: []const u8) ParseError!Uri { return uri; } +/// Parses the URI or returns an error. +/// The return value will contain unescaped strings pointing into the +/// original `text`. Each component that is provided, will be non-`null`. +pub fn parse(text: []const u8) ParseError!Uri { + var reader = SliceReader{ .slice = text }; + const scheme = reader.readWhile(isSchemeChar); + + // after the scheme, a ':' must appear + if (reader.get()) |c| { + if (c != ':') + return error.UnexpectedCharacter; + } else { + return error.InvalidFormat; + } + + var uri = try parseWithoutScheme(reader.readUntilEof()); + uri.scheme = scheme; + + return uri; +} + +/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. +/// arena owns any memory allocated by this function. 
+pub fn resolve(Base: Uri, R: Uri, strict: bool, arena: std.mem.Allocator) !Uri { + var T: Uri = undefined; + + if (R.scheme.len > 0 and !((!strict) and (std.mem.eql(u8, R.scheme, Base.scheme)))) { + T.scheme = R.scheme; + T.user = R.user; + T.host = R.host; + T.port = R.port; + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + T.query = R.query; + } else { + if (R.host) |host| { + T.user = R.user; + T.host = host; + T.port = R.port; + T.path = R.path; + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + T.query = R.query; + } else { + if (R.path.len == 0) { + T.path = Base.path; + if (R.query) |query| { + T.query = query; + } else { + T.query = Base.query; + } + } else { + if (R.path[0] == '/') { + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path }); + } else { + T.path = try std.fs.path.resolvePosix(arena, &.{ "/", Base.path, R.path }); + } + T.query = R.query; + } + + T.user = Base.user; + T.host = Base.host; + T.port = Base.port; + } + T.scheme = Base.scheme; + } + + T.fragment = R.fragment; + + return T; +} + const SliceReader = struct { const Self = @This(); @@ -284,6 +359,14 @@ fn isPathSeparator(c: u8) bool { }; } +fn isPathChar(c: u8) bool { + return isUnreserved(c) or isSubLimit(c) or c == '/' or c == ':' or c == '@'; +} + +fn isQueryChar(c: u8) bool { + return isPathChar(c) or c == '?'; +} + fn isQuerySeparator(c: u8) bool { return switch (c) { '#' => true, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 01bf957820..bc59459ff9 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -89,7 +89,7 @@ pub const StreamInterface = struct { }; pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error { + return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{ InsufficientEntropy, DiskQuota, LockViolation, diff --git a/lib/std/http/Client.zig 
b/lib/std/http/Client.zig index cac6571798..5b3a74d292 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -29,9 +29,10 @@ const ConnectionPool = std.TailQueue(Connection); const ConnectionNode = ConnectionPool.Node; /// Acquires an existing connection from the connection pool. This function is threadsafe. -pub fn acquire(client: *Client, node: *ConnectionNode) void { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); +/// If the caller already holds the connection mutex, it should pass `true` for `held`. +pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void { + if (!held) client.connection_mutex.lock(); + defer if (!held) client.connection_mutex.unlock(); client.connection_pool.remove(node); client.connection_used.append(node); @@ -40,16 +41,17 @@ pub fn acquire(client: *Client, node: *ConnectionNode) void { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. 
pub fn release(client: *Client, node: *ConnectionNode) void { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.remove(node); + if (node.data.closing) { node.data.close(client); return client.allocator.destroy(node); } - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); - - client.connection_used.remove(node); client.connection_pool.append(node); } @@ -83,7 +85,7 @@ pub const Connection = struct { } } - pub const ReadError = std.net.Stream.ReadError || error{ + pub const ReadError = net.Stream.ReadError || error{ TlsConnectionTruncated, TlsRecordOverflow, TlsDecodeError, @@ -115,7 +117,7 @@ pub const Connection = struct { } } - pub const WriteError = std.net.Stream.WriteError || error{}; + pub const WriteError = net.Stream.WriteError || error{}; pub const Writer = std.io.Writer(*Connection, WriteError, write); pub fn writer(conn: *Connection) Writer { @@ -139,14 +141,21 @@ pub const Request = struct { const read_buffer_size = 8192; const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); + uri: Uri, client: *Client, connection: *ConnectionNode, - redirects_left: u32, response: Response, /// These are stored in Request so that they are available when following /// redirects. headers: Headers, + redirects_left: u32, + handle_redirects: bool, + compression_init: bool, + + /// Used as a allocator for resolving redirects locations. + arena: std.heap.ArenaAllocator, + /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. 
read_buffer: [read_buffer_size]u8 = undefined, read_buffer_start: ReadBufferIndex = 0, @@ -661,6 +670,7 @@ pub const Request = struct { pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, + user_agent: []const u8 = "Zig (std.http)", connection: http.Connection = .keep_alive, transfer_encoding: RequestTransfer = .none, @@ -668,6 +678,7 @@ pub const Request = struct { }; pub const Options = struct { + handle_redirects: bool = true, max_redirects: u32 = 3, header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, @@ -703,10 +714,11 @@ pub const Request = struct { req.client.release(req.connection); } + req.arena.deinit(); req.* = undefined; } - const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{ + const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{ UnexpectedEndOfStream, TooManyHttpRedirects, HttpRedirectMissingLocation, @@ -723,9 +735,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { const amt = try req.readRawAdvanced(buffer[index..]); - const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; - - if (amt == 0 and zero_means_end) break; + if (amt == 0 and req.response.done) break; index += amt; } @@ -769,6 +779,8 @@ pub const Request = struct { } } else if (req.response.headers.content_length) |content_length| { req.response.next_chunk_length = content_length; + + if (content_length == 0) req.response.done = true; } else { req.response.done = true; } @@ -779,7 +791,7 @@ pub const Request = struct { return 0; } - pub const WaitForCompleteHeadError = ReadRawError || error { + pub const WaitForCompleteHeadError = ReadRawError || error{ UnexpectedEndOfStream, HttpHeadersExceededSizeLimit, @@ -810,27 +822,8 @@ pub const Request = struct { /// This one can return 0 without meaning EOF. 
fn readRawAdvanced(req: *Request, buffer: []u8) !usize { - if (req.response.done) { - if (req.response.headers.status.class() == .redirect) { - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = try std.Uri.parse(location); - const new_req = try req.client.request(new_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - } else { - return 0; - } - } + assert(req.response.state.isContent()); + if (req.response.done) return 0; // var in: []const u8 = undefined; if (req.read_buffer_start == req.read_buffer_len) { @@ -851,7 +844,7 @@ pub const Request = struct { const data_avail = req.response.next_chunk_length; const out_avail = buffer.len; - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { const can_read = @intCast(usize, @min(buf_avail, data_avail)); req.response.next_chunk_length -= can_read; @@ -859,7 +852,6 @@ pub const Request = struct { req.client.release(req.connection); req.connection = undefined; req.response.done = true; - continue; } return 0; // skip over as much data as possible @@ -943,7 +935,7 @@ pub const Request = struct { const data_avail = req.response.next_chunk_length; const out_avail = buffer.len - out_index; - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { const can_read = @intCast(usize, @min(buf_avail, data_avail)); req.response.next_chunk_length -= can_read; @@ -990,9 +982,41 @@ pub const Request = struct { } pub fn read(req: 
*Request, buffer: []u8) ReadError!usize { - if (!req.response.state.isContent()) try req.waitForCompleteHead(); + while (true) { + if (!req.response.state.isContent()) try req.waitForCompleteHead(); - if (req.response.compression == .none and req.response.state.isContent()) { + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { + assert(try req.readRaw(buffer) == 0); + + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); + + var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); + const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); + errdefer new_arena.deinit(); + + req.arena.deinit(); + req.arena = new_arena; + + const new_req = try req.client.request(resolved_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + } else { + break; + } + } + + if (req.response.compression == .none) { if (req.response.headers.transfer_compression) |compression| { switch (compression) { .compress => unreachable, @@ -1084,6 +1108,8 @@ pub const Request = struct { }; pub fn deinit(client: *Client) void { + client.connection_mutex.lock(); + var next = client.connection_pool.first; while (next) |node| { next = node.next; @@ -1106,7 +1132,7 @@ pub fn deinit(client: *Client) void { client.* = undefined; } -pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream); +pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); pub fn connect(client: *Client, 
host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { { // Search through the connection pool for a potential connection. @@ -1120,7 +1146,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio const same_protocol = node.data.protocol == protocol; if (same_host and same_port and same_protocol) { - client.acquire(node); + client.acquire(node, true); return node; } @@ -1168,6 +1194,7 @@ pub const RequestError = ConnectError || Connection.WriteError || error{ InvalidPadding, MissingEndCertificateMarker, Unseekable, + EndOfStream, }; pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { @@ -1196,27 +1223,52 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req } var req: Request = .{ + .uri = uri, .client = client, .headers = headers, .connection = try client.connect(host, port, protocol), .redirects_left = options.max_redirects, + .handle_redirects = options.handle_redirects, + .compression_init = false, .response = switch (options.header_strategy) { .dynamic => |max| Request.Response.initDynamic(max), .static => |buf| Request.Response.initStatic(buf), }, + .arena = undefined, }; + req.arena = std.heap.ArenaAllocator.init(client.allocator); + { var buffered = std.io.bufferedWriter(req.connection.data.writer()); const writer = buffered.writer(); + const escaped_path = try Uri.escapePath(client.allocator, uri.path); + defer client.allocator.free(escaped_path); + + const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null; + defer if (escaped_query) |q| client.allocator.free(q); + + const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null; + defer if (escaped_fragment) |f| client.allocator.free(f); + try writer.writeAll(@tagName(headers.method)); try writer.writeByte(' '); - try writer.writeAll(uri.path); + try 
writer.writeAll(escaped_path); + if (escaped_query) |q| { + try writer.writeByte('?'); + try writer.writeAll(q); + } + if (escaped_fragment) |f| { + try writer.writeByte('#'); + try writer.writeAll(f); + } try writer.writeByte(' '); try writer.writeAll(@tagName(headers.version)); try writer.writeAll("\r\nHost: "); try writer.writeAll(host); + try writer.writeAll("\r\nUser-Agent: "); + try writer.writeAll(headers.user_agent); if (headers.connection == .close) { try writer.writeAll("\r\nConnection: close"); } else { diff --git a/lib/std/net.zig b/lib/std/net.zig index cf112cbab9..7222433fd5 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -741,9 +741,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { return Stream{ .handle = sockfd }; } -const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error { +const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error{ // TODO: break this up into error sets from the various underlying functions - + TemporaryNameServerFailure, NameServerFailure, AddressFamilyNotSupported, @@ -760,7 +760,7 @@ const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || Incomplete, InvalidIpv4Mapping, InvalidIPAddressFormat, - + InterfaceNotFound, FileSystem, }; From 634e7155048aeaf15553d866783930f3d22b375c Mon Sep 17 00:00:00 2001 From: Nameless Date: Wed, 8 Mar 2023 08:20:53 -0600 Subject: [PATCH 5/6] std.http: split Client's parts into their own files --- lib/std/http.zig | 5 + lib/std/http/Client.zig | 988 +------------------------------ lib/std/http/Client/Request.zig | 488 +++++++++++++++ lib/std/http/Client/Response.zig | 506 ++++++++++++++++ 4 files changed, 1010 insertions(+), 977 deletions(-) create mode 100644 lib/std/http/Client/Request.zig create mode 100644 lib/std/http/Client/Response.zig diff 
--git a/lib/std/http.zig b/lib/std/http.zig index d4cc259f19..ef89f09925 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -248,9 +248,14 @@ pub const Status = enum(u10) { pub const TransferEncoding = enum { chunked, + // compression is intentionally omitted here, as std.http.Client stores it as content-encoding +}; + +pub const ContentEncoding = enum { compress, deflate, gzip, + zstd, }; pub const Connection = enum { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 5b3a74d292..d44b1d098d 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,6 +13,9 @@ const Uri = std.Uri; const Allocator = std.mem.Allocator; const testing = std.testing; +pub const Request = @import("Client/Request.zig"); +pub const Response = @import("Client/Response.zig"); + /// Used for tcpConnectToHost and storing HTTP headers when an externally /// managed buffer is not provided. allocator: Allocator, @@ -25,8 +28,8 @@ connection_mutex: std.Thread.Mutex = .{}, connection_pool: ConnectionPool = .{}, connection_used: ConnectionPool = .{}, -const ConnectionPool = std.TailQueue(Connection); -const ConnectionNode = ConnectionPool.Node; +pub const ConnectionPool = std.TailQueue(Connection); +pub const ConnectionNode = ConnectionPool.Node; /// Acquires an existing connection from the connection pool. This function is threadsafe. /// If the caller already holds the connection mutex, it should pass `true` for `held`. 
@@ -55,8 +58,9 @@ pub fn release(client: *Client, node: *ConnectionNode) void { client.connection_pool.append(node); } -const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); -const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); +pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); +pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); +pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{}); pub const Connection = struct { stream: net.Stream, @@ -137,976 +141,6 @@ pub const Connection = struct { } }; -pub const Request = struct { - const read_buffer_size = 8192; - const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); - - uri: Uri, - client: *Client, - connection: *ConnectionNode, - response: Response, - /// These are stored in Request so that they are available when following - /// redirects. - headers: Headers, - - redirects_left: u32, - handle_redirects: bool, - compression_init: bool, - - /// Used as a allocator for resolving redirects locations. - arena: std.heap.ArenaAllocator, - - /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. - read_buffer: [read_buffer_size]u8 = undefined, - read_buffer_start: ReadBufferIndex = 0, - read_buffer_len: ReadBufferIndex = 0, - - pub const Response = struct { - headers: Response.Headers, - state: State, - header_bytes_owned: bool, - /// This could either be a fixed buffer provided by the API user or it - /// could be our own array list. 
- header_bytes: std.ArrayListUnmanaged(u8), - max_header_bytes: usize, - next_chunk_length: u64, - done: bool = false, - - compression: union(enum) { - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - none: void, - } = .none, - - pub const Headers = struct { - status: http.Status, - version: http.Version, - location: ?[]const u8 = null, - content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, // This should only ever be chunked, compression is handled separately. - transfer_compression: ?http.TransferEncoding = null, - connection: http.Connection = .close, - - number_of_headers: usize = 0, - - pub fn parse(bytes: []const u8) !Response.Headers { - var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n"); - - const first_line = it.first(); - if (first_line.len < 12) - return error.ShortHttpStatusLine; - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); - - var headers: Response.Headers = .{ - .version = version, - .status = status, - }; - - while (it.next()) |line| { - headers.number_of_headers += 1; - - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - var line_it = mem.split(u8, line, ": "); - const header_name = line_it.first(); - const header_value = line_it.rest(); - if (std.ascii.eqlIgnoreCase(header_name, "location")) { - if (headers.location != null) return error.HttpHeadersInvalid; - headers.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (headers.content_length != null) return error.HttpHeadersInvalid; - headers.content_length = try std.fmt.parseInt(u64, header_value, 10); - } else if 
(std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - if (headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid; - - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = std.mem.splitBackwards(u8, header_value, ","); - - if (iter.next()) |first| { - const kind = std.meta.stringToEnum( - http.TransferEncoding, - std.mem.trim(u8, first, " "), - ) orelse - return error.HttpTransferEncodingUnsupported; - - switch (kind) { - .chunked => headers.transfer_encoding = .chunked, - .compress => headers.transfer_compression = .compress, - .deflate => headers.transfer_compression = .deflate, - .gzip => headers.transfer_compression = .gzip, - } - } - - if (iter.next()) |second| { - if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - - const kind = std.meta.stringToEnum( - http.TransferEncoding, - std.mem.trim(u8, second, " "), - ) orelse - return error.HttpTransferEncodingUnsupported; - - switch (kind) { - .chunked => return error.HttpHeadersInvalid, // chunked must come last - .compress => return error.HttpTransferEncodingUnsupported, // compress not supported - .deflate => headers.transfer_compression = .deflate, - .gzip => headers.transfer_compression = .gzip, - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - - const kind = std.meta.stringToEnum( - http.TransferEncoding, - std.mem.trim(u8, header_value, " "), - ) orelse - return error.HttpTransferEncodingUnsupported; - - switch (kind) { - .chunked => return error.HttpHeadersInvalid, // not transfer encoding - .compress => return error.HttpTransferEncodingUnsupported, // compress not supported - .deflate => headers.transfer_compression = .deflate, - .gzip => headers.transfer_compression = .gzip, - } - } else if 
(std.ascii.eqlIgnoreCase(header_name, "connection")) { - if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection = .keep_alive; - } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection = .close; - } else { - return error.HttpConnectionHeaderUnsupported; - } - } - } - - return headers; - } - - test "parse headers" { - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - const parsed = try Response.Headers.parse(example); - try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); - try testing.expectEqual(http.Status.moved_permanently, parsed.status); - try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse - return error.TestFailed); - try testing.expectEqual(@as(?u64, 220), parsed.content_length); - } - - test "header continuation" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeaderContinuationsUnsupported, - Response.Headers.parse(example), - ); - } - - test "extra content length" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Length: 220\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "content-length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeadersInvalid, - Response.Headers.parse(example), - ); - } - }; - - pub const State = enum { - /// Begin header parsing states. - invalid, - start, - seen_r, - seen_rn, - seen_rnr, - finished, - /// Begin transfer-encoding: chunked parsing states. 
- chunk_size_prefix_r, - chunk_size_prefix_n, - chunk_size, - chunk_r, - chunk_data, - - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, - }; - } - }; - - pub fn initDynamic(max: usize) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{}, - .max_header_bytes = max, - .header_bytes_owned = true, - .next_chunk_length = undefined, - }; - } - - pub fn initStatic(buf: []u8) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, - .max_header_bytes = buf.len, - .header_bytes_owned = false, - .next_chunk_length = undefined, - }; - } - - /// Returns how many bytes are part of HTTP headers. Always less than or - /// equal to bytes.len. If the amount returned is less than bytes.len, it - /// means the headers ended and the first byte after the double \r\n\r\n is - /// located at `bytes[result]`. 
- pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize { - var index: usize = 0; - - // TODO: https://github.com/ziglang/zig/issues/8220 - state: while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => unreachable, - .start => while (true) { - switch (bytes.len - index) { - 0 => return index, - 1 => { - if (bytes[index] == '\r') - r.state = .seen_r; - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 1] == '\r') { - r.state = .seen_r; - } - return index + 2; - }, - 3 => { - if (int16(bytes[index..][0..2]) == int16("\r\n") and - bytes[index + 2] == '\r') - { - r.state = .seen_rnr; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 2] == '\r') { - r.state = .seen_r; - } - return index + 3; - }, - 4...15 => { - if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index + 4; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and - bytes[index + 3] == '\r') - { - r.state = .seen_rnr; - index += 4; - continue :state; - } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - index += 4; - continue :state; - } else if (bytes[index + 3] == '\r') { - r.state = .seen_r; - index += 4; - continue :state; - } - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..16]; - const v: @Vector(16, u8) = chunk.*; - const matches_r = v == @splat(16, @as(u8, '\r')); - const iota = std.simd.iota(u8, 16); - const default = @splat(16, @as(u8, 16)); - const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default)); - switch (sub_index) { - 0...12 => { - index += sub_index + 4; - if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index; - } - continue; - }, - 13 => { - index += 16; - if (int16(chunk[14..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - continue 
:state; - } - continue; - }, - 14 => { - index += 16; - if (chunk[15] == '\n') { - r.state = .seen_rn; - continue :state; - } - continue; - }, - 15 => { - r.state = .seen_r; - index += 16; - continue :state; - }, - 16 => { - index += 16; - continue; - }, - else => unreachable, - } - }, - } - }, - - .seen_r => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\n' => r.state = .seen_rn, - '\r' => r.state = .seen_r, - else => r.state = .start, - } - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - return index + 2; - } - r.state = .start; - return index + 2; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\n\r") and - bytes[index + 2] == '\n') - { - r.state = .finished; - return index + 3; - } - index += 3; - r.state = .start; - continue :state; - }, - }, - .seen_rn => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_rnr, - else => r.state = .start, - } - return index + 1; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .finished; - return index + 2; - } - index += 2; - r.state = .start; - continue :state; - }, - }, - .seen_rnr => switch (bytes.len - index) { - 0 => return index, - else => { - if (bytes[index] == '\n') { - r.state = .finished; - return index + 1; - } - index += 1; - r.state = .start; - continue :state; - }, - }, - .chunk_size_prefix_r => unreachable, - .chunk_size_prefix_n => unreachable, - .chunk_size => unreachable, - .chunk_r => unreachable, - .chunk_data => unreachable, - } - - return index; - } - } - - pub fn findChunkedLen(r: *Response, bytes: []const u8) usize { - var i: usize = 0; - if (r.state == .chunk_size) { - while (i < bytes.len) : (i += 1) { - const digit = switch (bytes[i]) { - '0'...'9' => |b| b - '0', - 'A'...'Z' => |b| b - 'A' + 10, - 'a'...'z' => |b| b - 'a' + 10, - '\r' => { - r.state = .chunk_r; - i += 1; - break; - }, - 
else => { - r.state = .invalid; - return i; - }, - }; - const mul = @mulWithOverflow(r.next_chunk_length, 16); - if (mul[1] != 0) { - r.state = .invalid; - return i; - } - const add = @addWithOverflow(mul[0], digit); - if (add[1] != 0) { - r.state = .invalid; - return i; - } - r.next_chunk_length = add[0]; - } else { - return i; - } - } - assert(r.state == .chunk_r); - if (i == bytes.len) return i; - - if (bytes[i] == '\n') { - r.state = .chunk_data; - return i + 1; - } else { - r.state = .invalid; - return i; - } - } - - fn parseInt3(nnn: @Vector(3, u8)) u10 { - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); - } - - test parseInt3 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000".*)); - try expectEqual(@as(u10, 418), parseInt3("418".*)); - try expectEqual(@as(u10, 999), parseInt3("999".*)); - } - - test "find headers end basic" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); - try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); - try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); - } - - test "find headers end vectorized" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n" ++ - "\r\ncontent"; - try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); - } - - test "find headers end bug" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - const example = - "HTTP/1.1 200 OK\r\n" ++ - "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++ - 
"content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++ - "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++ - "Content-Type: application/x-gzip\r\n" ++ - "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++ - "Strict-Transport-Security: max-age=31536000\r\n" ++ - "Vary: Authorization,Accept-Encoding,Origin\r\n" ++ - "X-Content-Type-Options: nosniff\r\n" ++ - "X-Frame-Options: deny\r\n" ++ - "X-XSS-Protection: 1; mode=block\r\n" ++ - "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++ - "connection: close\r\n\r\n" ++ trail; - try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example)); - } - }; - - pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, - }; - - pub const Headers = struct { - version: http.Version = .@"HTTP/1.1", - method: http.Method = .GET, - user_agent: []const u8 = "Zig (std.http)", - connection: http.Connection = .keep_alive, - transfer_encoding: RequestTransfer = .none, - - custom: []const http.CustomHeader = &[_]http.CustomHeader{}, - }; - - pub const Options = struct { - handle_redirects: bool = true, - max_redirects: u32 = 3, - header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, - - pub const HeaderStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. 
- static: []u8, - }; - }; - - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - switch (req.response.compression) { - .none => {}, - .deflate => |*deflate| deflate.deinit(), - .gzip => |*gzip| gzip.deinit(), - } - - if (req.response.header_bytes_owned) { - req.response.header_bytes.deinit(req.client.allocator); - } - - if (!req.response.done) { - // If the response wasn't fully read, then we need to close the connection. - req.connection.data.closing = true; - req.client.release(req.connection); - } - - req.arena.deinit(); - req.* = undefined; - } - - const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{ - UnexpectedEndOfStream, - TooManyHttpRedirects, - HttpRedirectMissingLocation, - HttpHeadersInvalid, - }; - - const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw); - - /// Read from the underlying stream, without decompressing or parsing the headers. Must be called - /// after waitForCompleteHead() has returned successfully. - pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize { - assert(req.response.state.isContent()); - - var index: usize = 0; - while (index == 0) { - const amt = try req.readRawAdvanced(buffer[index..]); - if (amt == 0 and req.response.done) break; - index += amt; - } - - return index; - } - - fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { - switch (req.response.state) { - .invalid => unreachable, - .start, .seen_r, .seen_rn, .seen_rnr => {}, - else => return 0, // No more headers to read. 
- } - - const i = req.response.findHeadersEnd(buffer[0..]); - if (req.response.state == .invalid) return error.HttpHeadersInvalid; - - const headers_data = buffer[0..i]; - if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } - try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); - - if (req.response.state == .finished) { - req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - - if (req.response.headers.connection == .keep_alive) { - req.connection.data.closing = false; - } else { - req.connection.data.closing = true; - } - - if (req.response.headers.transfer_encoding) |transfer_encoding| { - switch (transfer_encoding) { - .chunked => { - req.response.next_chunk_length = 0; - req.response.state = .chunk_size; - }, - .compress => unreachable, - .deflate => unreachable, - .gzip => unreachable, - } - } else if (req.response.headers.content_length) |content_length| { - req.response.next_chunk_length = content_length; - - if (content_length == 0) req.response.done = true; - } else { - req.response.done = true; - } - - return i; - } - - return 0; - } - - pub const WaitForCompleteHeadError = ReadRawError || error{ - UnexpectedEndOfStream, - - HttpHeadersExceededSizeLimit, - ShortHttpStatusLine, - BadHttpVersion, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - }; - - /// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent. 
- pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void { - if (req.response.state.isContent()) return; - - while (true) { - const nread = try req.connection.data.read(req.read_buffer[0..]); - const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]); - - if (amt != 0) { - req.read_buffer_start = @intCast(ReadBufferIndex, amt); - req.read_buffer_len = @intCast(ReadBufferIndex, nread); - return; - } else if (nread == 0) { - return error.UnexpectedEndOfStream; - } - } - } - - /// This one can return 0 without meaning EOF. - fn readRawAdvanced(req: *Request, buffer: []u8) !usize { - assert(req.response.state.isContent()); - if (req.response.done) return 0; - - // var in: []const u8 = undefined; - if (req.read_buffer_start == req.read_buffer_len) { - const nread = try req.connection.data.read(req.read_buffer[0..]); - if (nread == 0) return error.UnexpectedEndOfStream; - - req.read_buffer_start = 0; - req.read_buffer_len = @intCast(ReadBufferIndex, nread); - } - - var out_index: usize = 0; - while (true) { - switch (req.response.state) { - .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - // TODO https://github.com/ziglang/zig/issues/14039 - const buf_avail = req.read_buffer_len - req.read_buffer_start; - const data_avail = req.response.next_chunk_length; - const out_avail = buffer.len; - - if (req.handle_redirects and req.response.headers.status.class() == .redirect) { - const can_read = @intCast(usize, @min(buf_avail, data_avail)); - req.response.next_chunk_length -= can_read; - - if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); - req.connection = undefined; - req.response.done = true; - } - - return 0; // skip over as much data as possible - } - - const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); - req.response.next_chunk_length -= can_read; - - mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]); - 
req.read_buffer_start += @intCast(ReadBufferIndex, can_read); - - if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); - req.connection = undefined; - req.response.done = true; - } - - return can_read; - }, - .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) { - 0 => return out_index, - 1 => switch (req.read_buffer[req.read_buffer_start]) { - '\r' => { - req.response.state = .chunk_size_prefix_n; - return out_index; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) { - int16("\r\n") => { - req.read_buffer_start += 2; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) { - 0 => return out_index, - else => switch (req.read_buffer[req.read_buffer_start]) { - '\n' => { - req.read_buffer_start += 1; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size, .chunk_r => { - const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]); - switch (req.response.state) { - .invalid => return error.HttpHeadersInvalid, - .chunk_data => { - if (req.response.next_chunk_length == 0) { - req.response.done = true; - req.client.release(req.connection); - req.connection = undefined; - - return out_index; - } - - req.read_buffer_start += @intCast(ReadBufferIndex, i); - continue; - }, - .chunk_size => return out_index, - else => unreachable, - } - }, - .chunk_data => { - // TODO https://github.com/ziglang/zig/issues/14039 - const buf_avail = req.read_buffer_len - req.read_buffer_start; - const data_avail = req.response.next_chunk_length; - const out_avail = buffer.len - out_index; - - if 
(req.handle_redirects and req.response.headers.status.class() == .redirect) { - const can_read = @intCast(usize, @min(buf_avail, data_avail)); - req.response.next_chunk_length -= can_read; - - if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); - req.connection = undefined; - req.response.done = true; - continue; - } - - return 0; // skip over as much data as possible - } - - const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); - req.response.next_chunk_length -= can_read; - - mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]); - req.read_buffer_start += @intCast(ReadBufferIndex, can_read); - out_index += can_read; - - if (req.response.next_chunk_length == 0) { - req.response.state = .chunk_size_prefix_r; - - continue; - } - - return out_index; - }, - } - } - } - - pub const ReadError = DeflateDecompressor.Error || GzipDecompressor.Error || WaitForCompleteHeadError || error{ - BadHeader, - InvalidCompression, - StreamTooLong, - InvalidWindowSize, - }; - - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - while (true) { - if (!req.response.state.isContent()) try req.waitForCompleteHead(); - - if (req.handle_redirects and req.response.headers.status.class() == .redirect) { - assert(try req.readRaw(buffer) == 0); - - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); - - var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); - const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); - errdefer new_arena.deinit(); - - req.arena.deinit(); - req.arena = new_arena; - - const new_req = try 
req.client.request(resolved_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - } else { - break; - } - } - - if (req.response.compression == .none) { - if (req.response.headers.transfer_compression) |compression| { - switch (compression) { - .compress => unreachable, - .deflate => req.response.compression = .{ - .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }), - }, - .gzip => req.response.compression = .{ - .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }), - }, - .chunked => unreachable, - } - } - } - - return switch (req.response.compression) { - .deflate => |*deflate| try deflate.read(buffer), - .gzip => |*gzip| try gzip.read(buffer), - else => try req.readRaw(buffer), - }; - } - - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - pub const WriteError = Connection.WriteError || error{MessageTooLong}; - - pub const Writer = std.io.Writer(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. 
- pub fn write(req: *Request, bytes: []const u8) !usize { - switch (req.headers.transfer_encoding) { - .chunked => { - try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.data.writeAll(bytes); - try req.connection.data.writeAll("\r\n"); - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.data.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } - } - - /// Finish the body of a request. This notifies the server that you have no more data to send. - pub fn finish(req: *Request) !void { - switch (req.headers.transfer_encoding) { - .chunked => try req.connection.data.writeAll("0\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - } - - inline fn int16(array: *const [2]u8) u16 { - return @bitCast(u16, array.*); - } - - inline fn int32(array: *const [4]u8) u32 { - return @bitCast(u32, array.*); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); - } - - test { - const builtin = @import("builtin"); - - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - _ = Response; - } -}; - pub fn deinit(client: *Client) void { client.connection_mutex.lock(); @@ -1231,8 +265,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req .handle_redirects = options.handle_redirects, .compression_init = false, .response = switch (options.header_strategy) { - .dynamic => |max| Request.Response.initDynamic(max), - .static => |buf| Request.Response.initStatic(buf), + .dynamic => |max| Response.initDynamic(max), + .static => |buf| Response.initStatic(buf), }, .arena = undefined, }; @@ -1274,7 +308,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req } else { try writer.writeAll("\r\nConnection: keep-alive"); } - try writer.writeAll("\r\nAccept-Encoding: gzip, deflate"); + try 
writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd"); switch (headers.transfer_encoding) { .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"), diff --git a/lib/std/http/Client/Request.zig b/lib/std/http/Client/Request.zig new file mode 100644 index 0000000000..26ce5cb7bf --- /dev/null +++ b/lib/std/http/Client/Request.zig @@ -0,0 +1,488 @@ +const std = @import("std"); +const http = std.http; +const Uri = std.Uri; +const mem = std.mem; +const assert = std.debug.assert; + +const Client = @import("../Client.zig"); +const Connection = Client.Connection; +const ConnectionNode = Client.ConnectionNode; +const Response = @import("Response.zig"); + +const Request = @This(); + +const read_buffer_size = 8192; +const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); + +uri: Uri, +client: *Client, +connection: *ConnectionNode, +response: Response, +/// These are stored in Request so that they are available when following +/// redirects. +headers: Headers, + +redirects_left: u32, +handle_redirects: bool, +compression_init: bool, + +/// Used as a allocator for resolving redirects locations. +arena: std.heap.ArenaAllocator, + +/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. 
+read_buffer: [read_buffer_size]u8 = undefined, +read_buffer_start: ReadBufferIndex = 0, +read_buffer_len: ReadBufferIndex = 0, + +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +pub const Headers = struct { + version: http.Version = .@"HTTP/1.1", + method: http.Method = .GET, + user_agent: []const u8 = "zig (std.http)", + connection: http.Connection = .keep_alive, + transfer_encoding: RequestTransfer = .none, + + custom: []const http.CustomHeader = &[_]http.CustomHeader{}, +}; + +pub const Options = struct { + handle_redirects: bool = true, + max_redirects: u32 = 3, + header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, + + pub const HeaderStrategy = union(enum) { + /// In this case, the client's Allocator will be used to store the + /// entire HTTP header. This value is the maximum total size of + /// HTTP headers allowed, otherwise + /// error.HttpHeadersExceededSizeLimit is returned from read(). + dynamic: usize, + /// This is used to store the entire HTTP header. If the HTTP + /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` + /// is returned from read(). When this is used, `error.OutOfMemory` + /// cannot be returned from `read()`. + static: []u8, + }; +}; + +/// Frees all resources associated with the request. +pub fn deinit(req: *Request) void { + switch (req.response.compression) { + .none => {}, + .deflate => |*deflate| deflate.deinit(), + .gzip => |*gzip| gzip.deinit(), + .zstd => |*zstd| zstd.deinit(), + } + + if (req.response.header_bytes_owned) { + req.response.header_bytes.deinit(req.client.allocator); + } + + if (!req.response.done) { + // If the response wasn't fully read, then we need to close the connection. 
+ req.connection.data.closing = true; + req.client.release(req.connection); + } + + req.arena.deinit(); + req.* = undefined; +} + +pub const ReadRawError = Connection.ReadError || Uri.ParseError || Client.RequestError || error{ + UnexpectedEndOfStream, + TooManyHttpRedirects, + HttpRedirectMissingLocation, + HttpHeadersInvalid, +}; + +pub const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw); + +/// Read from the underlying stream, without decompressing or parsing the headers. Must be called +/// after waitForCompleteHead() has returned successfully. +pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize { + assert(req.response.state.isContent()); + + var index: usize = 0; + while (index == 0) { + const amt = try req.readRawAdvanced(buffer[index..]); + if (amt == 0 and req.response.done) break; + index += amt; + } + + return index; +} + +fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { + switch (req.response.state) { + .invalid => unreachable, + .start, .seen_r, .seen_rn, .seen_rnr => {}, + else => return 0, // No more headers to read. 
+ } + + const i = req.response.findHeadersEnd(buffer[0..]); + if (req.response.state == .invalid) return error.HttpHeadersInvalid; + + const headers_data = buffer[0..i]; + if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { + return error.HttpHeadersExceededSizeLimit; + } + try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); + + if (req.response.state == .finished) { + req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); + + if (req.response.upgrade) |_| { + req.connection.data.closing = false; + req.response.done = true; + return i; + } + + if (req.response.headers.connection == .keep_alive) { + req.connection.data.closing = false; + } else { + req.connection.data.closing = true; + } + + if (req.response.headers.transfer_encoding) |transfer_encoding| { + switch (transfer_encoding) { + .chunked => { + req.response.next_chunk_length = 0; + req.response.state = .chunk_size; + }, + } + } else if (req.response.headers.content_length) |content_length| { + req.response.next_chunk_length = content_length; + + if (content_length == 0) req.response.done = true; + } else { + req.response.done = true; + } + + return i; + } + + return 0; +} + +pub const WaitForCompleteHeadError = ReadRawError || error{ + UnexpectedEndOfStream, + + HttpHeadersExceededSizeLimit, + ShortHttpStatusLine, + BadHttpVersion, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, +}; + +/// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent. 
+pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void { + if (req.response.state.isContent()) return; + + while (true) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]); + + if (amt != 0) { + req.read_buffer_start = @intCast(ReadBufferIndex, amt); + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + return; + } else if (nread == 0) { + return error.UnexpectedEndOfStream; + } + } +} + +/// This one can return 0 without meaning EOF. +fn readRawAdvanced(req: *Request, buffer: []u8) !usize { + assert(req.response.state.isContent()); + if (req.response.done) return 0; + + // var in: []const u8 = undefined; + if (req.read_buffer_start == req.read_buffer_len) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + if (nread == 0) return error.UnexpectedEndOfStream; + + req.read_buffer_start = 0; + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + } + + var out_index: usize = 0; + while (true) { + switch (req.response.state) { + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable, + .finished => { + // TODO https://github.com/ziglang/zig/issues/14039 + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len; + + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + } + + return 0; // skip over as much data as possible + } + + const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start 
+= @intCast(ReadBufferIndex, can_read); + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + } + + return can_read; + }, + .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) { + 0 => return out_index, + 1 => switch (req.read_buffer[req.read_buffer_start]) { + '\r' => { + req.response.state = .chunk_size_prefix_n; + return out_index; + }, + else => { + req.response.state = .invalid; + return error.HttpHeadersInvalid; + }, + }, + else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) { + int16("\r\n") => { + req.read_buffer_start += 2; + req.response.state = .chunk_size; + continue; + }, + else => { + req.response.state = .invalid; + return error.HttpHeadersInvalid; + }, + }, + }, + .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) { + 0 => return out_index, + else => switch (req.read_buffer[req.read_buffer_start]) { + '\n' => { + req.read_buffer_start += 1; + req.response.state = .chunk_size; + continue; + }, + else => { + req.response.state = .invalid; + return error.HttpHeadersInvalid; + }, + }, + }, + .chunk_size, .chunk_r => { + const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]); + switch (req.response.state) { + .invalid => return error.HttpHeadersInvalid, + .chunk_data => { + if (req.response.next_chunk_length == 0) { + req.response.done = true; + req.client.release(req.connection); + req.connection = undefined; + + return out_index; + } + + req.read_buffer_start += @intCast(ReadBufferIndex, i); + continue; + }, + .chunk_size => return out_index, + else => unreachable, + } + }, + .chunk_data => { + // TODO https://github.com/ziglang/zig/issues/14039 + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (req.handle_redirects and 
req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + continue; + } + + return 0; // skip over as much data as possible + } + + const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start += @intCast(ReadBufferIndex, can_read); + out_index += can_read; + + if (req.response.next_chunk_length == 0) { + req.response.state = .chunk_size_prefix_r; + + continue; + } + + return out_index; + }, + } + } +} + +pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ + BadHeader, + InvalidCompression, + StreamTooLong, + InvalidWindowSize, + CompressionNotSupported +}; + +pub const Reader = std.io.Reader(*Request, ReadError, read); + +pub fn reader(req: *Request) Reader { + return .{ .context = req }; +} + +pub fn read(req: *Request, buffer: []u8) ReadError!usize { + while (true) { + if (!req.response.state.isContent()) try req.waitForCompleteHead(); + + if (req.handle_redirects and req.response.headers.status.class() == .redirect) { + assert(try req.readRaw(buffer) == 0); + + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); + + var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); + const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); + errdefer new_arena.deinit(); + + req.arena.deinit(); + req.arena = 
new_arena; + + const new_req = try req.client.request(resolved_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + } else { + break; + } + } + + if (req.response.compression == .none) { + if (req.response.headers.transfer_compression) |compression| { + switch (compression) { + .compress => return error.CompressionNotSupported, + .deflate => req.response.compression = .{ + .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }), + }, + .gzip => req.response.compression = .{ + .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }), + }, + .zstd => req.response.compression = .{ + .zstd = std.compress.zstd.decompressStream(req.client.allocator, ReaderRaw{ .context = req }), + }, + } + } + } + + return switch (req.response.compression) { + .deflate => |*deflate| try deflate.read(buffer), + .gzip => |*gzip| try gzip.read(buffer), + .zstd => |*zstd| try zstd.read(buffer), + else => try req.readRaw(buffer), + }; +} + +pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; +} + +pub const WriteError = Connection.WriteError || error{MessageTooLong}; + +pub const Writer = std.io.Writer(*Request, WriteError, write); + +pub fn writer(req: *Request) Writer { + return .{ .context = req }; +} + +/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. 
+pub fn write(req: *Request, bytes: []const u8) !usize {
+    switch (req.headers.transfer_encoding) {
+        .chunked => {
+            if (bytes.len == 0) return 0; // a zero-length chunk would be read as the end of the body
+            try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
+            try req.connection.data.writeAll(bytes);
+            try req.connection.data.writeAll("\r\n");
+            return bytes.len;
+        },
+        .content_length => |*len| {
+            if (len.* < bytes.len) return error.MessageTooLong; // more than the remaining declared length
+
+            const amt = try req.connection.data.write(bytes);
+            len.* -= amt;
+            return amt;
+        },
+        .none => return error.NotWriteable,
+    }
+}
+
+/// Finish the body of a request. This notifies the server that you have no more data to send.
+pub fn finish(req: *Request) !void {
+    switch (req.headers.transfer_encoding) {
+        .chunked => try req.connection.data.writeAll("0\r\n\r\n"), // last-chunk, then the CRLF ending the (empty) trailer section
+        .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+        .none => {},
+    }
+}
+
+inline fn int16(array: *const [2]u8) u16 {
+    return @bitCast(u16, array.*);
+}
+
+inline fn int32(array: *const [4]u8) u32 {
+    return @bitCast(u32, array.*);
+}
+
+inline fn int64(array: *const [8]u8) u64 {
+    return @bitCast(u64, array.*);
+}
+
+test {
+    const builtin = @import("builtin");
+
+    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
+    _ = Response;
+}
diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig
new file mode 100644
index 0000000000..bc064f9a20
--- /dev/null
+++ b/lib/std/http/Client/Response.zig
@@ -0,0 +1,506 @@
+const std = @import("std");
+const http = std.http;
+const mem = std.mem;
+const testing = std.testing;
+const assert = std.debug.assert;
+
+const Client = @import("../Client.zig");
+const Response = @This();
+
+headers: Headers,
+state: State,
+header_bytes_owned: bool,
+/// This could either be a fixed buffer provided by the API user or it
+/// could be our own array list.
+header_bytes: std.ArrayListUnmanaged(u8), +max_header_bytes: usize, +next_chunk_length: u64, +done: bool = false, + +compression: union(enum) { + deflate: Client.DeflateDecompressor, + gzip: Client.GzipDecompressor, + zstd: Client.ZstdDecompressor, + none: void, +} = .none, + +pub const Headers = struct { + status: http.Status, + version: http.Version, + location: ?[]const u8 = null, + content_length: ?u64 = null, + transfer_encoding: ?http.TransferEncoding = null, + transfer_compression: ?http.ContentEncoding = null, + connection: http.Connection = .close, + + number_of_headers: usize = 0, + + pub fn parse(bytes: []const u8) !Headers { + var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n"); + + const first_line = it.first(); + if (first_line.len < 12) + return error.ShortHttpStatusLine; + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.BadHttpVersion, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); + + var headers: Headers = .{ + .version = version, + .status = status, + }; + + while (it.next()) |line| { + headers.number_of_headers += 1; + + if (line.len == 0) return error.HttpHeadersInvalid; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + var line_it = mem.split(u8, line, ": "); + const header_name = line_it.first(); + const header_value = line_it.rest(); + if (std.ascii.eqlIgnoreCase(header_name, "location")) { + if (headers.location != null) return error.HttpHeadersInvalid; + headers.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (headers.content_length != null) return error.HttpHeadersInvalid; + headers.content_length = try std.fmt.parseInt(u64, header_value, 10); + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + if 
(headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid; + + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = std.mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const trimmed = std.mem.trim(u8, first, " "); + + if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { + headers.transfer_encoding = te; + } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |second| { + if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + + const trimmed = std.mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + + const trimmed = std.mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { + headers.connection = .keep_alive; + } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { + headers.connection = .close; + } else { + return error.HttpConnectionHeaderUnsupported; + } + } + } + + return headers; + } + + test "parse headers" { + const example = + "HTTP/1.1 301 Moved Permanently\r\n" ++ + "Location: https://www.example.com/\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n\r\n"; + const parsed = try 
Headers.parse(example); + try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); + try testing.expectEqual(http.Status.moved_permanently, parsed.status); + try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse + return error.TestFailed); + try testing.expectEqual(@as(?u64, 220), parsed.content_length); + } + + test "header continuation" { + const example = + "HTTP/1.0 200 OK\r\n" ++ + "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n\r\n"; + try testing.expectError( + error.HttpHeaderContinuationsUnsupported, + Headers.parse(example), + ); + } + + test "extra content length" { + const example = + "HTTP/1.0 200 OK\r\n" ++ + "Content-Length: 220\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "content-length: 220\r\n\r\n"; + try testing.expectError( + error.HttpHeadersInvalid, + Headers.parse(example), + ); + } +}; + +inline fn int16(array: *const [2]u8) u16 { + return @bitCast(u16, array.*); +} + +inline fn int32(array: *const [4]u8) u32 { + return @bitCast(u32, array.*); +} + +inline fn int64(array: *const [8]u8) u64 { + return @bitCast(u64, array.*); +} + +pub const State = enum { + /// Begin header parsing states. + invalid, + start, + seen_r, + seen_rn, + seen_rnr, + finished, + /// Begin transfer-encoding: chunked parsing states. 
+ chunk_size_prefix_r, + chunk_size_prefix_n, + chunk_size, + chunk_r, + chunk_data, + + pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, + }; + } +}; + +pub fn initDynamic(max: usize) Response { + return .{ + .state = .start, + .headers = undefined, + .header_bytes = .{}, + .max_header_bytes = max, + .header_bytes_owned = true, + .next_chunk_length = undefined, + }; +} + +pub fn initStatic(buf: []u8) Response { + return .{ + .state = .start, + .headers = undefined, + .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, + .max_header_bytes = buf.len, + .header_bytes_owned = false, + .next_chunk_length = undefined, + }; +} + +/// Returns how many bytes are part of HTTP headers. Always less than or +/// equal to bytes.len. If the amount returned is less than bytes.len, it +/// means the headers ended and the first byte after the double \r\n\r\n is +/// located at `bytes[result]`. 
+pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize { + var index: usize = 0; + + // TODO: https://github.com/ziglang/zig/issues/8220 + state: while (true) { + switch (r.state) { + .invalid => unreachable, + .finished => unreachable, + .start => while (true) { + switch (bytes.len - index) { + 0 => return index, + 1 => { + if (bytes[index] == '\r') + r.state = .seen_r; + return index + 1; + }, + 2 => { + if (int16(bytes[index..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + } else if (bytes[index + 1] == '\r') { + r.state = .seen_r; + } + return index + 2; + }, + 3 => { + if (int16(bytes[index..][0..2]) == int16("\r\n") and + bytes[index + 2] == '\r') + { + r.state = .seen_rnr; + } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + } else if (bytes[index + 2] == '\r') { + r.state = .seen_r; + } + return index + 3; + }, + 4...15 => { + if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) { + r.state = .finished; + return index + 4; + } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and + bytes[index + 3] == '\r') + { + r.state = .seen_rnr; + index += 4; + continue :state; + } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + index += 4; + continue :state; + } else if (bytes[index + 3] == '\r') { + r.state = .seen_r; + index += 4; + continue :state; + } + index += 4; + continue; + }, + else => { + const chunk = bytes[index..][0..16]; + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default)); + switch (sub_index) { + 0...12 => { + index += sub_index + 4; + if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) { + r.state = .finished; + return index; + } + continue; + }, + 13 => { + index += 16; + if (int16(chunk[14..][0..2]) == int16("\n\r")) { + r.state = .seen_rnr; + continue :state; 
+ } + continue; + }, + 14 => { + index += 16; + if (chunk[15] == '\n') { + r.state = .seen_rn; + continue :state; + } + continue; + }, + 15 => { + r.state = .seen_r; + index += 16; + continue :state; + }, + 16 => { + index += 16; + continue; + }, + else => unreachable, + } + }, + } + }, + + .seen_r => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\n' => r.state = .seen_rn, + '\r' => r.state = .seen_r, + else => r.state = .start, + } + return index + 1; + }, + 2 => { + if (int16(bytes[index..][0..2]) == int16("\n\r")) { + r.state = .seen_rnr; + return index + 2; + } + r.state = .start; + return index + 2; + }, + else => { + if (int16(bytes[index..][0..2]) == int16("\n\r") and + bytes[index + 2] == '\n') + { + r.state = .finished; + return index + 3; + } + index += 3; + r.state = .start; + continue :state; + }, + }, + .seen_rn => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => r.state = .seen_rnr, + else => r.state = .start, + } + return index + 1; + }, + else => { + if (int16(bytes[index..][0..2]) == int16("\r\n")) { + r.state = .finished; + return index + 2; + } + index += 2; + r.state = .start; + continue :state; + }, + }, + .seen_rnr => switch (bytes.len - index) { + 0 => return index, + else => { + if (bytes[index] == '\n') { + r.state = .finished; + return index + 1; + } + index += 1; + r.state = .start; + continue :state; + }, + }, + .chunk_size_prefix_r => unreachable, + .chunk_size_prefix_n => unreachable, + .chunk_size => unreachable, + .chunk_r => unreachable, + .chunk_data => unreachable, + } + + return index; + } +} + +pub fn findChunkedLen(r: *Response, bytes: []const u8) usize { + var i: usize = 0; + if (r.state == .chunk_size) { + while (i < bytes.len) : (i += 1) { + const digit = switch (bytes[i]) { + '0'...'9' => |b| b - '0', + 'A'...'Z' => |b| b - 'A' + 10, + 'a'...'z' => |b| b - 'a' + 10, + '\r' => { + r.state = .chunk_r; + i += 1; + break; + }, + else => { + 
r.state = .invalid; + return i; + }, + }; + const mul = @mulWithOverflow(r.next_chunk_length, 16); + if (mul[1] != 0) { + r.state = .invalid; + return i; + } + const add = @addWithOverflow(mul[0], digit); + if (add[1] != 0) { + r.state = .invalid; + return i; + } + r.next_chunk_length = add[0]; + } else { + return i; + } + } + assert(r.state == .chunk_r); + if (i == bytes.len) return i; + + if (bytes[i] == '\n') { + r.state = .chunk_data; + return i + 1; + } else { + r.state = .invalid; + return i; + } +} + +fn parseInt3(nnn: @Vector(3, u8)) u10 { + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); +} + +test parseInt3 { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000".*)); + try expectEqual(@as(u10, 418), parseInt3("418".*)); + try expectEqual(@as(u10, 999), parseInt3("999".*)); +} + +test "find headers end basic" { + var buffer: [1]u8 = undefined; + var r = Response.initStatic(&buffer); + try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); + try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); + try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); +} + +test "find headers end vectorized" { + var buffer: [1]u8 = undefined; + var r = Response.initStatic(&buffer); + const example = + "HTTP/1.1 301 Moved Permanently\r\n" ++ + "Location: https://www.example.com/\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n" ++ + "\r\ncontent"; + try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); +} + +test "find headers end bug" { + var buffer: [1]u8 = undefined; + var r = Response.initStatic(&buffer); + const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + const example = + "HTTP/1.1 200 OK\r\n" ++ + "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++ + "content-disposition: 
attachment; filename=zig-0.10.0.tar.gz\r\n" ++ + "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++ + "Content-Type: application/x-gzip\r\n" ++ + "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++ + "Strict-Transport-Security: max-age=31536000\r\n" ++ + "Vary: Authorization,Accept-Encoding,Origin\r\n" ++ + "X-Content-Type-Options: nosniff\r\n" ++ + "X-Frame-Options: deny\r\n" ++ + "X-XSS-Protection: 1; mode=block\r\n" ++ + "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++ + "Transfer-Encoding: chunked\r\n" ++ + "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++ + "connection: close\r\n\r\n" ++ trail; + try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example)); +} From 524e0cd987a52a60ce1014aa27cd73f99a3b9958 Mon Sep 17 00:00:00 2001 From: Nameless Date: Wed, 8 Mar 2023 11:27:13 -0600 Subject: [PATCH 6/6] std.http: rework connection pool into its own type --- lib/std/http/Client.zig | 189 +++++++++++++++++++------------ lib/std/http/Client/Request.zig | 22 ++-- lib/std/http/Client/Response.zig | 5 +- lib/std/std.zig | 5 + 4 files changed, 134 insertions(+), 87 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d44b1d098d..baf0239388 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -16,6 +16,9 @@ const testing = std.testing; pub const Request = @import("Client/Request.zig"); pub const Response = @import("Client/Response.zig"); +pub const default_connection_pool_size = 32; +const connection_pool_size = std.options.http_connection_pool_size; + /// Used for tcpConnectToHost and storing HTTP headers when an externally /// managed buffer is not provided. allocator: Allocator, @@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. 
next_https_rescan_certs: bool = true, -connection_mutex: std.Thread.Mutex = .{}, connection_pool: ConnectionPool = .{}, -connection_used: ConnectionPool = .{}, -pub const ConnectionPool = std.TailQueue(Connection); -pub const ConnectionNode = ConnectionPool.Node; +pub const ConnectionPool = struct { + pub const Criteria = struct { + host: []const u8, + port: u16, + is_tls: bool, + }; -/// Acquires an existing connection from the connection pool. This function is threadsafe. -/// If the caller already holds the connection mutex, it should pass `true` for `held`. -pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void { - if (!held) client.connection_mutex.lock(); - defer if (!held) client.connection_mutex.unlock(); + const Queue = std.TailQueue(Connection); + pub const Node = Queue.Node; - client.connection_pool.remove(node); - client.connection_used.append(node); -} + mutex: std.Thread.Mutex = .{}, + used: Queue = .{}, + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = default_connection_pool_size, -/// Tries to release a connection back to the connection pool. This function is threadsafe. -/// If the connection is marked as closing, it will be closed instead. -pub fn release(client: *Client, node: *ConnectionNode) void { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. 
+    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
 
-    client.connection_used.remove(node);
+        var next = pool.free.last; // start from the most recently released connection
+        while (next) |node| : (next = node.prev) {
+            if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+            if (node.data.port != criteria.port) continue;
+            if (!std.mem.eql(u8, node.data.host, criteria.host)) continue; // only reuse a connection to the same host
 
-    if (node.data.closing) {
-        node.data.close(client);
+            pool.acquireUnsafe(node);
+            return node;
+        }
 
-        return client.allocator.destroy(node);
+        return null;
     }
 
-    client.connection_pool.append(node);
-}
+    /// Moves `node` from the free list to the used list. Not threadsafe: the caller must hold `pool.mutex`.
+    pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
+        pool.free.remove(node);
+        pool.free_len -= 1;
+
+        pool.used.append(node);
+    }
+
+    /// Locks the pool and moves `node` from the free list to the used list. This function is threadsafe.
+    pub fn acquire(pool: *ConnectionPool, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        return pool.acquireUnsafe(node);
+    }
+
+    /// Tries to release a connection back to the connection pool. This function is threadsafe.
+    /// A closing connection (or any connection when `free_size` is 0) is closed and destroyed; a full free list evicts its oldest entry.
+    pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        pool.used.remove(node);
+
+        if (node.data.closing or pool.free_size == 0) {
+            node.data.close(client);
+
+            return client.allocator.destroy(node);
+        }
+
+        if (pool.free_len >= pool.free_size) {
+            // Full: evict the oldest idle connection so `node` takes its slot instead of being leaked.
+            const popped = pool.free.popFirst() orelse unreachable; // free_len >= free_size >= 1 here
+            pool.free_len -= 1;
+            popped.data.close(client);
+            client.allocator.destroy(popped);
+        }
+
+        pool.free.append(node);
+        pool.free_len += 1;
+    }
+
+    /// Adds a newly created node to the pool of used connections. This function is threadsafe.
+ pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.append(node); + } + + pub fn deinit(pool: *ConnectionPool, client: *Client) void { + pool.mutex.lock(); + + var next = pool.free.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); + } + + next = pool.used.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); + } + + pool.* = undefined; + } +}; pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); @@ -142,25 +221,7 @@ pub const Connection = struct { }; pub fn deinit(client: *Client) void { - client.connection_mutex.lock(); - - var next = client.connection_pool.first; - while (next) |node| { - next = node.next; - - node.data.close(client); - - client.allocator.destroy(node); - } - - next = client.connection_used.first; - while (next) |node| { - next = node.next; - - node.data.close(client); - - client.allocator.destroy(node); - } + client.connection_pool.deinit(client); client.ca_bundle.deinit(client.allocator); client.* = undefined; @@ -168,36 +229,25 @@ pub fn deinit(client: *Client) void { pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { - { // Search through the connection pool for a potential connection. 
- client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { + if (client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .is_tls = protocol == .tls, + })) |node| + return node; - var potential = client.connection_pool.last; - while (potential) |node| { - const same_host = mem.eql(u8, node.data.host, host); - const same_port = node.data.port == port; - const same_protocol = node.data.protocol == protocol; - - if (same_host and same_port and same_protocol) { - client.acquire(node, true); - return node; - } - - potential = node.prev; - } - } - - const conn = try client.allocator.create(ConnectionNode); + const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; - conn.* = .{ .data = .{ + conn.data = .{ .stream = try net.tcpConnectToHost(client.allocator, host, port), .tls_client = undefined, .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, - } }; + }; switch (protocol) { .plain => {}, @@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio }, } - { - client.connection_mutex.lock(); - defer client.connection_mutex.unlock(); - - client.connection_used.append(conn); - } + client.connection_pool.addUsed(conn); return conn; } @@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; if (client.next_https_rescan_certs and protocol == .tls) { - client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. - defer client.connection_mutex.unlock(); + client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. 
+ defer client.connection_pool.mutex.unlock(); if (client.next_https_rescan_certs) { try client.ca_bundle.rescan(client.allocator); diff --git a/lib/std/http/Client/Request.zig b/lib/std/http/Client/Request.zig index 26ce5cb7bf..9e2ebd2d6c 100644 --- a/lib/std/http/Client/Request.zig +++ b/lib/std/http/Client/Request.zig @@ -6,7 +6,7 @@ const assert = std.debug.assert; const Client = @import("../Client.zig"); const Connection = Client.Connection; -const ConnectionNode = Client.ConnectionNode; +const ConnectionNode = Client.ConnectionPool.Node; const Response = @import("Response.zig"); const Request = @This(); @@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void { if (!req.response.done) { // If the response wasn't fully read, then we need to close the connection. req.connection.data.closing = true; - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); } req.arena.deinit(); @@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { if (req.response.state == .finished) { req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - if (req.response.upgrade) |_| { + if (req.response.headers.upgrade) |_| { req.connection.data.closing = false; req.response.done = true; return i; @@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.response.next_chunk_length -= can_read; if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; } @@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.read_buffer_start += @intCast(ReadBufferIndex, can_read); if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; } @@ -293,7 +293,7 @@ fn readRawAdvanced(req: 
*Request, buffer: []u8) !usize { .chunk_data => { if (req.response.next_chunk_length == 0) { req.response.done = true; - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; return out_index; @@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { req.response.next_chunk_length -= can_read; if (req.response.next_chunk_length == 0) { - req.client.release(req.connection); + req.client.connection_pool.release(req.client, req.connection); req.connection = undefined; req.response.done = true; continue; @@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize { } } -pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ - BadHeader, - InvalidCompression, - StreamTooLong, - InvalidWindowSize, - CompressionNotSupported -}; +pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported }; pub const Reader = std.io.Reader(*Request, ReadError, read); diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig index bc064f9a20..8b2a9a4918 100644 --- a/lib/std/http/Client/Response.zig +++ b/lib/std/http/Client/Response.zig @@ -32,6 +32,7 @@ pub const Headers = struct { transfer_encoding: ?http.TransferEncoding = null, transfer_compression: ?http.ContentEncoding = null, connection: http.Connection = .close, + upgrade: ?[]const u8 = null, number_of_headers: usize = 0, @@ -93,7 +94,7 @@ pub const Headers = struct { if (iter.next()) |second| { if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - + const trimmed = std.mem.trim(u8, second, " "); if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { @@ -122,6 +123,8 
@@ pub const Headers = struct { } else { return error.HttpConnectionHeaderUnsupported; } + } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { + headers.upgrade = header_value; } } diff --git a/lib/std/std.zig b/lib/std/std.zig index c1c682e224..e888ade659 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -185,6 +185,11 @@ pub const options = struct { options_override.keep_sigpipe else false; + + pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size")) + options_override.http_connection_pool_size + else + http.Client.default_connection_pool_size; }; // This forces the start.zig file to be imported, and the comptime logic inside that