zig/lib/std/http/Client.zig
2023-03-09 14:55:13 -06:00

1312 lines
51 KiB
Zig

//! TODO: send connection: keep-alive and LRU cache a configurable number of
//! open connections to skip DNS and TLS handshake for subsequent requests.
//!
//! This API is *not* thread safe.
const std = @import("../std.zig");
const mem = std.mem;
const assert = std.debug.assert;
const http = std.http;
const net = std.net;
const Client = @This();
const Uri = std.Uri;
const Allocator = std.mem.Allocator;
const testing = std.testing;
/// Used for tcpConnectToHost and storing HTTP headers when an externally
/// managed buffer is not provided.
allocator: Allocator,
/// Certificate bundle passed to the TLS client when an HTTPS connection is
/// opened. It is rescanned by `request` when `next_https_rescan_certs` is set.
ca_bundle: std.crypto.Certificate.Bundle = .{},
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
/// Guards `connection_pool` and `connection_used`. `request` also holds it
/// while rescanning `ca_bundle`.
connection_mutex: std.Thread.Mutex = .{},
/// Idle connections that `connect` may hand out again.
connection_pool: ConnectionPool = .{},
/// Connections currently checked out (moved here by `acquire`/`connect`,
/// removed by `release`).
connection_used: ConnectionPool = .{},
/// List type holding connections; each node embeds one `Connection` in `data`.
const ConnectionPool = std.TailQueue(Connection);
const ConnectionNode = ConnectionPool.Node;
/// Moves `node` out of the idle pool and into the in-use list.
/// The connection mutex is taken for the duration of the move unless `held`
/// is `true`, in which case the caller must already own it.
pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
    const need_lock = !held;
    if (need_lock) client.connection_mutex.lock();
    defer if (need_lock) client.connection_mutex.unlock();

    client.connection_pool.remove(node);
    client.connection_used.append(node);
}
/// Returns a checked-out connection. Takes the connection mutex.
/// A connection flagged `closing` is closed and its node destroyed; any other
/// connection is appended to the idle pool for reuse.
pub fn release(client: *Client, node: *ConnectionNode) void {
    client.connection_mutex.lock();
    defer client.connection_mutex.unlock();

    client.connection_used.remove(node);

    if (!node.data.closing) {
        client.connection_pool.append(node);
        return;
    }

    node.data.close(client);
    client.allocator.destroy(node);
}
/// Decompressor types for response bodies. Both pull their compressed input
/// from `Request.ReaderRaw`, i.e. from `Request.readRaw`.
const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
/// One network stream to a server, optionally wrapped in TLS, together with
/// the host/port/protocol it was opened for so `Client.connect` can match it
/// for reuse.
pub const Connection = struct {
    stream: net.Stream,
    /// undefined unless protocol is tls. Allocated with the client's allocator
    /// in `Client.connect` and destroyed in `close`.
    tls_client: *std.crypto.tls.Client,
    protocol: Protocol,
    /// Host this connection was opened for; freed in `close`.
    host: []u8,
    port: u16,
    /// This connection has been part of a non keepalive request and cannot be
    /// added to the pool.
    closing: bool = false,

    pub const Protocol = enum { plain, tls };

    /// Reads into `buffer`, going through the TLS client when the protocol is tls.
    pub fn read(conn: *Connection, buffer: []u8) !usize {
        if (conn.protocol == .tls) {
            return conn.tls_client.read(conn.stream, buffer);
        }
        return conn.stream.read(buffer);
    }

    /// Like `read`, but forwards to the underlying `readAtLeast` with `len`.
    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
        if (conn.protocol == .tls) {
            return conn.tls_client.readAtLeast(conn.stream, buffer, len);
        }
        return conn.stream.readAtLeast(buffer, len);
    }

    pub const ReadError = net.Stream.ReadError || error{
        TlsConnectionTruncated,
        TlsRecordOverflow,
        TlsDecodeError,
        TlsAlert,
        TlsBadRecordMac,
        Overflow,
        TlsBadLength,
        TlsIllegalParameter,
        TlsUnexpectedMessage,
    };

    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    pub fn reader(conn: *Connection) Reader {
        return .{ .context = conn };
    }

    /// Writes all of `buffer`, going through the TLS client when the protocol is tls.
    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
        if (conn.protocol == .tls) {
            return conn.tls_client.writeAll(conn.stream, buffer);
        }
        return conn.stream.writeAll(buffer);
    }

    /// Writes some prefix of `buffer` and returns how many bytes were written.
    pub fn write(conn: *Connection, buffer: []const u8) !usize {
        if (conn.protocol == .tls) {
            return conn.tls_client.write(conn.stream, buffer);
        }
        return conn.stream.write(buffer);
    }

    pub const WriteError = net.Stream.WriteError || error{};

    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    pub fn writer(conn: *Connection) Writer {
        return .{ .context = conn };
    }

    /// Closes the stream and frees everything the connection owns (`host`, and
    /// the TLS client when the protocol is tls) using `client.allocator`.
    /// Does not free the node that contains the connection.
    pub fn close(conn: *Connection, client: *const Client) void {
        switch (conn.protocol) {
            .plain => {},
            .tls => {
                // try to cleanly close the TLS connection, for any server that cares.
                // Failure is deliberately ignored: the stream is closed right after.
                _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
                client.allocator.destroy(conn.tls_client);
            },
        }

        conn.stream.close();
        client.allocator.free(conn.host);
    }
};
pub const Request = struct {
/// Size in bytes of the per-request read buffer.
const read_buffer_size = 8192;
/// Smallest integer type able to hold any offset into `read_buffer`,
/// including `read_buffer_size` itself.
const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
/// Target of this request; replaced by the resolved location when a redirect
/// is followed.
uri: Uri,
client: *Client,
/// Connection node checked out from `client`. Set to `undefined` once the
/// response body has been fully consumed and the node was released.
connection: *ConnectionNode,
response: Response,
/// These are stored in Request so that they are available when following
/// redirects.
headers: Headers,
/// How many more redirects may be followed before `read` fails with
/// `error.TooManyHttpRedirects`.
redirects_left: u32,
handle_redirects: bool,
// NOTE(review): set to `false` in `Client.request`; no other use is visible in this file.
compression_init: bool,
/// Used as an allocator for resolving redirect locations.
arena: std.heap.ArenaAllocator,
/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning.
read_buffer: [read_buffer_size]u8 = undefined,
/// Offset of the first unconsumed byte in `read_buffer`.
read_buffer_start: ReadBufferIndex = 0,
/// Number of valid bytes in `read_buffer`; the buffer is empty when this
/// equals `read_buffer_start`.
read_buffer_len: ReadBufferIndex = 0,
pub const Response = struct {
/// Parsed response head. Left `undefined` by `initDynamic`/`initStatic` and
/// filled in by `checkForCompleteHead` once the whole head was received.
headers: Response.Headers,
/// Current position of the header-end / chunked-body state machine.
state: State,
/// Whether `header_bytes` was allocated with the client's allocator and must
/// be freed by `Request.deinit`.
header_bytes_owned: bool,
/// This could either be a fixed buffer provided by the API user or it
/// could be our own array list.
header_bytes: std.ArrayListUnmanaged(u8),
/// Upper bound on `header_bytes.items.len`; exceeding it yields
/// `error.HttpHeadersExceededSizeLimit`.
max_header_bytes: usize,
/// Remaining bytes of the current chunk, or of the whole body when the
/// response carries a Content-Length. While a chunk-size line is being
/// parsed this accumulates the size read so far.
next_chunk_length: u64,
/// Set once the body has been fully consumed (or there is none).
done: bool = false,
/// Active body decompressor, created by `Request.read` from
/// `headers.transfer_compression`.
compression: union(enum) {
deflate: DeflateDecompressor,
gzip: GzipDecompressor,
none: void,
} = .none,
/// Parsed form of a response head: the status line plus the header fields
/// this client acts on. All other header fields are counted but ignored.
pub const Headers = struct {
    status: http.Status,
    version: http.Version,
    /// Value of the Location header. This is a slice into the bytes that
    /// were passed to `parse`, not a copy.
    location: ?[]const u8 = null,
    content_length: ?u64 = null,
    transfer_encoding: ?http.TransferEncoding = null, // This should only ever be chunked, compression is handled separately.
    /// Compression named by Transfer-Encoding or Content-Encoding. Only
    /// `.deflate` or `.gzip` are ever stored here.
    transfer_compression: ?http.TransferEncoding = null,
    connection: http.Connection = .close,
    /// Number of header lines seen, whether recognized or not.
    number_of_headers: usize = 0,

    /// Parses a complete response head. `bytes` must include the terminating
    /// "\r\n\r\n". Slices stored in the result point into `bytes`.
    ///
    /// Fails with:
    /// * `error.ShortHttpStatusLine` / `error.BadHttpVersion` /
    ///   `error.HttpHeadersInvalid` for a malformed status line, an empty
    ///   header line, or a duplicated/conflicting header.
    /// * `error.HttpHeaderContinuationsUnsupported` for folded header lines.
    /// * `error.HttpTransferEncodingUnsupported` for unknown or unsupported
    ///   codings (including `compress`).
    /// * `error.HttpConnectionHeaderUnsupported` for a Connection value
    ///   other than keep-alive/close.
    /// * the errors of `std.fmt.parseInt` for a bad Content-Length.
    pub fn parse(bytes: []const u8) !Response.Headers {
        // Guard the subtraction below; callers normally pass a head that
        // ends in "\r\n\r\n".
        if (bytes.len < 4) return error.ShortHttpStatusLine;

        var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");

        const first_line = it.first();
        if (first_line.len < 12)
            return error.ShortHttpStatusLine;

        const version: http.Version = switch (int64(first_line[0..8])) {
            int64("HTTP/1.0") => .@"HTTP/1.0",
            int64("HTTP/1.1") => .@"HTTP/1.1",
            else => return error.BadHttpVersion,
        };
        if (first_line[8] != ' ') return error.HttpHeadersInvalid;

        // parseInt3 does no validation, so reject non-digits here instead
        // of turning garbage into an arbitrary status value.
        for (first_line[9..12]) |c| {
            switch (c) {
                '0'...'9' => {},
                else => return error.HttpHeadersInvalid,
            }
        }
        const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));

        var headers: Response.Headers = .{
            .version = version,
            .status = status,
        };

        while (it.next()) |line| {
            headers.number_of_headers += 1;

            if (line.len == 0) return error.HttpHeadersInvalid;
            switch (line[0]) {
                ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
                else => {},
            }

            // The whitespace after the colon is optional, so split on the
            // first ':' and trim the value instead of requiring ": ".
            // A line without any colon is ignored.
            const colon = mem.indexOfScalar(u8, line, ':') orelse continue;
            const header_name = line[0..colon];
            const header_value = mem.trim(u8, line[colon + 1 ..], " \t");

            if (std.ascii.eqlIgnoreCase(header_name, "location")) {
                if (headers.location != null) return error.HttpHeadersInvalid;
                headers.location = header_value;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                if (headers.content_length != null) return error.HttpHeadersInvalid;
                headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
            } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                if (headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid;

                // Transfer-Encoding: second, first
                // Transfer-Encoding: deflate, chunked
                var iter = std.mem.splitBackwards(u8, header_value, ",");

                if (iter.next()) |first| {
                    const kind = std.meta.stringToEnum(
                        http.TransferEncoding,
                        std.mem.trim(u8, first, " "),
                    ) orelse
                        return error.HttpTransferEncodingUnsupported;

                    switch (kind) {
                        .chunked => headers.transfer_encoding = .chunked,
                        // There is no decompressor for `compress`; reject it
                        // here rather than storing a value the reader cannot handle.
                        .compress => return error.HttpTransferEncodingUnsupported,
                        .deflate => headers.transfer_compression = .deflate,
                        .gzip => headers.transfer_compression = .gzip,
                    }
                }

                if (iter.next()) |second| {
                    if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;

                    const kind = std.meta.stringToEnum(
                        http.TransferEncoding,
                        std.mem.trim(u8, second, " "),
                    ) orelse
                        return error.HttpTransferEncodingUnsupported;

                    switch (kind) {
                        .chunked => return error.HttpHeadersInvalid, // chunked must come last
                        .compress => return error.HttpTransferEncodingUnsupported, // compress not supported
                        .deflate => headers.transfer_compression = .deflate,
                        .gzip => headers.transfer_compression = .gzip,
                    }
                }

                if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
                if (headers.transfer_compression != null) return error.HttpHeadersInvalid;

                const kind = std.meta.stringToEnum(
                    http.TransferEncoding,
                    std.mem.trim(u8, header_value, " "),
                ) orelse
                    return error.HttpTransferEncodingUnsupported;

                switch (kind) {
                    .chunked => return error.HttpHeadersInvalid, // not transfer encoding
                    .compress => return error.HttpTransferEncodingUnsupported, // compress not supported
                    .deflate => headers.transfer_compression = .deflate,
                    .gzip => headers.transfer_compression = .gzip,
                }
            } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
                if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
                    headers.connection = .keep_alive;
                } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
                    headers.connection = .close;
                } else {
                    return error.HttpConnectionHeaderUnsupported;
                }
            }
        }

        return headers;
    }

    test "parse headers" {
        const example =
            "HTTP/1.1 301 Moved Permanently\r\n" ++
            "Location: https://www.example.com/\r\n" ++
            "Content-Type: text/html; charset=UTF-8\r\n" ++
            "Content-Length: 220\r\n\r\n";
        const parsed = try Response.Headers.parse(example);
        try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
        try testing.expectEqual(http.Status.moved_permanently, parsed.status);
        try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse
            return error.TestFailed);
        try testing.expectEqual(@as(?u64, 220), parsed.content_length);
    }

    test "header continuation" {
        const example =
            "HTTP/1.0 200 OK\r\n" ++
            "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++
            "Content-Length: 220\r\n\r\n";
        try testing.expectError(
            error.HttpHeaderContinuationsUnsupported,
            Response.Headers.parse(example),
        );
    }

    test "extra content length" {
        const example =
            "HTTP/1.0 200 OK\r\n" ++
            "Content-Length: 220\r\n" ++
            "Content-Type: text/html; charset=UTF-8\r\n" ++
            "content-length: 220\r\n\r\n";
        try testing.expectError(
            error.HttpHeadersInvalid,
            Response.Headers.parse(example),
        );
    }

    test "optional whitespace after colon" {
        const example =
            "HTTP/1.1 200 OK\r\n" ++
            "Content-Length:220\r\n" ++
            "Connection:   keep-alive\r\n\r\n";
        const parsed = try Response.Headers.parse(example);
        try testing.expectEqual(@as(?u64, 220), parsed.content_length);
        try testing.expectEqual(http.Connection.keep_alive, parsed.connection);
    }

    test "compress transfer-encoding is rejected" {
        const example =
            "HTTP/1.1 200 OK\r\n" ++
            "Transfer-Encoding: compress\r\n\r\n";
        try testing.expectError(
            error.HttpTransferEncodingUnsupported,
            Response.Headers.parse(example),
        );
    }

    test "non-numeric status code" {
        const example = "HTTP/1.1 2x0 OK\r\n\r\n";
        try testing.expectError(
            error.HttpHeadersInvalid,
            Response.Headers.parse(example),
        );
    }
};
/// Position of the response state machine. The first group tracks how much
/// of the "\r\n\r\n" header terminator has been seen; the second group tracks
/// progress through a chunked body.
pub const State = enum {
    /// Begin header parsing states.
    invalid,
    start,
    seen_r,
    seen_rn,
    seen_rnr,
    finished,

    /// Begin transfer-encoding: chunked parsing states.
    chunk_size_prefix_r,
    chunk_size_prefix_n,
    chunk_size,
    chunk_r,
    chunk_data,

    /// Whether the head has been fully received, i.e. the state is
    /// `.finished` or one of the chunked-body states.
    pub fn isContent(self: State) bool {
        switch (self) {
            .invalid,
            .start,
            .seen_r,
            .seen_rn,
            .seen_rnr,
            => return false,

            .finished,
            .chunk_size_prefix_r,
            .chunk_size_prefix_n,
            .chunk_size,
            .chunk_r,
            .chunk_data,
            => return true,
        }
    }
};
/// Creates a response whose header bytes live in a list grown with the
/// client's allocator, limited to `max` bytes. `headers` and
/// `next_chunk_length` start out undefined.
pub fn initDynamic(max: usize) Response {
    return Response{
        .headers = undefined,
        .state = .start,
        .header_bytes_owned = true,
        .header_bytes = .{},
        .max_header_bytes = max,
        .next_chunk_length = undefined,
    };
}
/// Creates a response whose header bytes are stored in the caller-provided
/// `buf`; the header size limit is `buf.len`. `headers` and
/// `next_chunk_length` start out undefined.
pub fn initStatic(buf: []u8) Response {
    const capacity = buf.len;
    return Response{
        .headers = undefined,
        .state = .start,
        .header_bytes_owned = false,
        .header_bytes = .{ .items = buf[0..0], .capacity = capacity },
        .max_header_bytes = capacity,
        .next_chunk_length = undefined,
    };
}
/// Returns how many bytes are part of HTTP headers. Always less than or
/// equal to bytes.len. If the amount returned is less than bytes.len, it
/// means the headers ended and the first byte after the double \r\n\r\n is
/// located at `bytes[result]`.
///
/// Progress through a partially seen "\r\n\r\n" is kept in `r.state`, so the
/// terminator may be split across successive calls. `r.state` becomes
/// `.finished` when the terminator has been found. Must only be called while
/// `r.state` is one of the header parsing states (`.start`, `.seen_r`,
/// `.seen_rn`, `.seen_rnr`).
pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize {
var index: usize = 0;
// TODO: https://github.com/ziglang/zig/issues/8220
state: while (true) {
switch (r.state) {
.invalid => unreachable,
.finished => unreachable,
// No part of the terminator is pending. Scan forward, choosing a
// strategy by how many bytes remain.
.start => while (true) {
switch (bytes.len - index) {
0 => return index,
// 1 to 3 bytes left: consume them all and only record which
// prefix of the terminator the input ends with.
1 => {
if (bytes[index] == '\r')
r.state = .seen_r;
return index + 1;
},
2 => {
if (int16(bytes[index..][0..2]) == int16("\r\n")) {
r.state = .seen_rn;
} else if (bytes[index + 1] == '\r') {
r.state = .seen_r;
}
return index + 2;
},
3 => {
if (int16(bytes[index..][0..2]) == int16("\r\n") and
bytes[index + 2] == '\r')
{
r.state = .seen_rnr;
} else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) {
r.state = .seen_rn;
} else if (bytes[index + 3 - 1] == '\r') {
r.state = .seen_r;
}
return index + 3;
},
// 4 to 15 bytes left: look at a 4-byte window. Either it is the
// whole terminator, or it ends with a prefix of it (handled by the
// matching seen_* state), or it is skipped entirely.
4...15 => {
if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) {
r.state = .finished;
return index + 4;
} else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and
bytes[index + 3] == '\r')
{
r.state = .seen_rnr;
index += 4;
continue :state;
} else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) {
r.state = .seen_rn;
index += 4;
continue :state;
} else if (bytes[index + 3] == '\r') {
r.state = .seen_r;
index += 4;
continue :state;
}
index += 4;
continue;
},
// 16 or more bytes left: find the first '\r' in the next 16 bytes
// with a vector compare. `sub_index` is its offset, or 16 if there
// is none.
else => {
const chunk = bytes[index..][0..16];
const v: @Vector(16, u8) = chunk.*;
const matches_r = v == @splat(16, @as(u8, '\r'));
const iota = std.simd.iota(u8, 16);
const default = @splat(16, @as(u8, 16));
const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default));
switch (sub_index) {
// The '\r' and the three bytes after it are all inside this
// chunk: test for the full terminator, otherwise skip past
// those four bytes.
// NOTE(review): the skip does not re-examine a '\r' within the
// three bytes following the first one — confirm this is
// acceptable for the inputs expected here.
0...12 => {
index += sub_index + 4;
if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) {
r.state = .finished;
return index;
}
continue;
},
// The first '\r' is near the end of the chunk: consume the whole
// chunk and carry the partial match over in r.state.
13 => {
index += 16;
if (int16(chunk[14..][0..2]) == int16("\n\r")) {
r.state = .seen_rnr;
continue :state;
}
continue;
},
14 => {
index += 16;
if (chunk[15] == '\n') {
r.state = .seen_rn;
continue :state;
}
continue;
},
15 => {
r.state = .seen_r;
index += 16;
continue :state;
},
// No '\r' in this chunk at all.
16 => {
index += 16;
continue;
},
else => unreachable,
}
},
}
},
// A '\r' was the last byte consumed; "\n\r\n" completes the terminator.
.seen_r => switch (bytes.len - index) {
0 => return index,
1 => {
switch (bytes[index]) {
'\n' => r.state = .seen_rn,
'\r' => r.state = .seen_r,
else => r.state = .start,
}
return index + 1;
},
2 => {
if (int16(bytes[index..][0..2]) == int16("\n\r")) {
r.state = .seen_rnr;
return index + 2;
}
r.state = .start;
return index + 2;
},
else => {
if (int16(bytes[index..][0..2]) == int16("\n\r") and
bytes[index + 2] == '\n')
{
r.state = .finished;
return index + 3;
}
index += 3;
r.state = .start;
continue :state;
},
},
// "\r\n" was consumed; "\r\n" completes the terminator.
.seen_rn => switch (bytes.len - index) {
0 => return index,
1 => {
switch (bytes[index]) {
'\r' => r.state = .seen_rnr,
else => r.state = .start,
}
return index + 1;
},
else => {
if (int16(bytes[index..][0..2]) == int16("\r\n")) {
r.state = .finished;
return index + 2;
}
index += 2;
r.state = .start;
continue :state;
},
},
// "\r\n\r" was consumed; a single '\n' completes the terminator.
.seen_rnr => switch (bytes.len - index) {
0 => return index,
else => {
if (bytes[index] == '\n') {
r.state = .finished;
return index + 1;
}
index += 1;
r.state = .start;
continue :state;
},
},
// Body states are never valid here.
.chunk_size_prefix_r => unreachable,
.chunk_size_prefix_n => unreachable,
.chunk_size => unreachable,
.chunk_r => unreachable,
.chunk_data => unreachable,
}
return index;
}
}
/// Parses (part of) a chunk-size line of a chunked body.
///
/// Must be called with `r.state` equal to `.chunk_size` or `.chunk_r`.
/// Hexadecimal digits are accumulated into `r.next_chunk_length`, so the line
/// may be split across several calls. Returns the number of bytes of `bytes`
/// that were consumed. On return `r.state` is:
/// * `.chunk_data` when the terminating "\r\n" was consumed,
/// * `.chunk_size` or `.chunk_r` when more input is needed,
/// * `.invalid` on a non-hexadecimal character, a size that overflows `u64`,
///   or a '\r' that is not followed by '\n'.
pub fn findChunkedLen(r: *Response, bytes: []const u8) usize {
    var i: usize = 0;
    if (r.state == .chunk_size) {
        while (i < bytes.len) : (i += 1) {
            const digit = switch (bytes[i]) {
                '0'...'9' => |b| b - '0',
                // Only A-F/a-f are hexadecimal digits. Accepting the whole
                // alphabet would turn e.g. 'G' into the value 16.
                'A'...'F' => |b| b - 'A' + 10,
                'a'...'f' => |b| b - 'a' + 10,
                '\r' => {
                    r.state = .chunk_r;
                    i += 1;
                    break;
                },
                else => {
                    r.state = .invalid;
                    return i;
                },
            };

            const mul = @mulWithOverflow(r.next_chunk_length, 16);
            if (mul[1] != 0) {
                r.state = .invalid;
                return i;
            }

            const add = @addWithOverflow(mul[0], digit);
            if (add[1] != 0) {
                r.state = .invalid;
                return i;
            }

            r.next_chunk_length = add[0];
        } else {
            // Ran out of input in the middle of the size.
            return i;
        }
    }
    assert(r.state == .chunk_r);
    if (i == bytes.len) return i;

    if (bytes[i] == '\n') {
        r.state = .chunk_data;
        return i + 1;
    } else {
        r.state = .invalid;
        return i;
    }
}

test "findChunkedLen accepts only hexadecimal digits" {
    var buffer: [1]u8 = undefined;
    var r = Response.initStatic(&buffer);

    r.state = .chunk_size;
    r.next_chunk_length = 0;
    try testing.expectEqual(@as(usize, 5), r.findChunkedLen("1aF\r\n"));
    try testing.expectEqual(State.chunk_data, r.state);
    try testing.expectEqual(@as(u64, 0x1af), r.next_chunk_length);

    r.state = .chunk_size;
    r.next_chunk_length = 0;
    try testing.expectEqual(@as(usize, 1), r.findChunkedLen("1g\r\n"));
    try testing.expectEqual(State.invalid, r.state);
}
/// Converts three ASCII decimal digits into their numeric value.
/// No validation is performed: bytes that are not digits yield an
/// unspecified (wrapped) result rather than an error.
fn parseInt3(nnn: @Vector(3, u8)) u10 {
    const hundreds: u10 = nnn[0] -% '0';
    const tens: u10 = nnn[1] -% '0';
    const ones: u10 = nnn[2] -% '0';
    return hundreds *% 100 +% tens *% 10 +% ones;
}

test parseInt3 {
    const Case = struct { text: *const [3]u8, value: u10 };
    const cases = [_]Case{
        .{ .text = "000", .value = 0 },
        .{ .text = "418", .value = 418 },
        .{ .text = "999", .value = 999 },
    };
    for (cases) |case| {
        try std.testing.expectEqual(case.value, parseInt3(case.text.*));
    }
}
// The terminator search keeps its state between calls: the head is fed in
// three pieces and only the last one contains "\r\n\r\n".
test "find headers end basic" {
var buffer: [1]u8 = undefined;
var r = Response.initStatic(&buffer);
try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4"));
try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18"));
try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah"));
}
// Input long enough to go through the 16-byte vector path.
test "find headers end vectorized" {
var buffer: [1]u8 = undefined;
var r = Response.initStatic(&buffer);
const example =
"HTTP/1.1 301 Moved Permanently\r\n" ++
"Location: https://www.example.com/\r\n" ++
"Content-Type: text/html; charset=UTF-8\r\n" ++
"Content-Length: 220\r\n" ++
"\r\ncontent";
try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example));
}
// Regression test: a long head followed by trailing body bytes; the result
// must point exactly at the first byte of `trail`.
test "find headers end bug" {
var buffer: [1]u8 = undefined;
var r = Response.initStatic(&buffer);
const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
const example =
"HTTP/1.1 200 OK\r\n" ++
"Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++
"content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++
"Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++
"Content-Type: application/x-gzip\r\n" ++
"ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++
"Strict-Transport-Security: max-age=31536000\r\n" ++
"Vary: Authorization,Accept-Encoding,Origin\r\n" ++
"X-Content-Type-Options: nosniff\r\n" ++
"X-Frame-Options: deny\r\n" ++
"X-XSS-Protection: 1; mode=block\r\n" ++
"Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++
"Transfer-Encoding: chunked\r\n" ++
"X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++
"connection: close\r\n\r\n" ++ trail;
try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example));
}
};
/// How the request body, if any, is framed on the wire.
pub const RequestTransfer = union(enum) {
/// A Content-Length header is sent with this value. `write` subtracts every
/// written byte from it and `finish` fails unless it has reached zero.
content_length: u64,
/// A "Transfer-Encoding: chunked" header is sent and `write` frames each
/// call as one chunk.
chunked: void,
/// No body framing header is sent; `write` fails with `error.NotWriteable`.
none: void,
};
/// Request head settings written by `Client.request`.
pub const Headers = struct {
/// Written verbatim (via its tag name) at the end of the request line.
version: http.Version = .@"HTTP/1.1",
method: http.Method = .GET,
/// Value of the User-Agent header.
user_agent: []const u8 = "Zig (std.http)",
/// `.close` sends "Connection: close"; anything else sends
/// "Connection: keep-alive".
connection: http.Connection = .keep_alive,
transfer_encoding: RequestTransfer = .none,
/// Extra headers appended as "name: value" lines after the built-in ones.
custom: []const http.CustomHeader = &[_]http.CustomHeader{},
};
/// Per-request behavior switches passed to `Client.request`.
pub const Options = struct {
/// When true, `read` transparently follows responses whose status class is
/// redirect instead of returning their body.
handle_redirects: bool = true,
/// Number of redirects `read` may follow before failing with
/// `error.TooManyHttpRedirects`.
max_redirects: u32 = 3,
header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
pub const HeaderStrategy = union(enum) {
/// In this case, the client's Allocator will be used to store the
/// entire HTTP header. This value is the maximum total size of
/// HTTP headers allowed, otherwise
/// error.HttpHeadersExceededSizeLimit is returned from read().
dynamic: usize,
/// This is used to store the entire HTTP header. If the HTTP
/// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
/// is returned from read(). When this is used, `error.OutOfMemory`
/// cannot be returned from `read()`.
static: []u8,
};
};
/// Releases everything the request owns: the decompressor, the header buffer
/// (when it was allocated by the client) and the redirect arena. If the body
/// was not read to completion the connection is marked as closing and handed
/// back to the client, which closes it. The request is undefined afterwards.
pub fn deinit(req: *Request) void {
    switch (req.response.compression) {
        .deflate => |*deflate| deflate.deinit(),
        .gzip => |*gzip| gzip.deinit(),
        .none => {},
    }

    if (req.response.header_bytes_owned) req.response.header_bytes.deinit(req.client.allocator);

    const body_unfinished = !req.response.done;
    if (body_unfinished) {
        // Unread body bytes are still pending on the stream, so the
        // connection cannot be reused for another request.
        req.connection.data.closing = true;
        req.client.release(req.connection);
    }

    req.arena.deinit();
    req.* = undefined;
}
/// Errors that `readRaw` may return: connection read errors, the errors of
/// parsing a URI and of issuing a request, plus the body/redirect specific
/// ones listed here.
const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{
UnexpectedEndOfStream,
TooManyHttpRedirects,
HttpRedirectMissingLocation,
HttpHeadersInvalid,
};
/// Reader over the raw (still compressed) body; this is the input of the
/// decompressors.
const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw);
/// Reads body bytes from the connection without decompressing them. Must be
/// called after waitForCompleteHead() has returned successfully.
/// Keeps calling `readRawAdvanced` until it produces data; returns 0 only
/// once the response is done.
pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize {
    assert(req.response.state.isContent());

    while (true) {
        const amt = try req.readRawAdvanced(buffer);
        if (amt != 0) return amt;
        if (req.response.done) return 0;
    }
}
/// Feeds freshly read bytes to the header-end detector and stores the header
/// part of them in `response.header_bytes`.
///
/// Returns 0 while the head is still incomplete (or when it had already been
/// completed before this call). Once the terminating "\r\n\r\n" is seen, the
/// head is parsed, the response is set up for body reading, and the number of
/// bytes of `buffer` that belong to the head is returned; the rest of
/// `buffer` is body data.
///
/// When the response has no body to read, the response is marked done and the
/// connection is released right away, exactly as `readRawAdvanced` does when
/// a body ends. Without this the connection would stay in the client's in-use
/// list, because `deinit` only releases connections of unfinished responses.
fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
    switch (req.response.state) {
        .invalid => unreachable,
        .start, .seen_r, .seen_rn, .seen_rnr => {},
        else => return 0, // No more headers to read.
    }

    const i = req.response.findHeadersEnd(buffer[0..]);
    if (req.response.state == .invalid) return error.HttpHeadersInvalid;

    const headers_data = buffer[0..i];
    if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
        return error.HttpHeadersExceededSizeLimit;
    }
    try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);

    if (req.response.state != .finished) return 0;

    req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);

    // Only an explicit keep-alive allows the connection to go back to the pool.
    req.connection.data.closing = req.response.headers.connection != .keep_alive;

    var has_body = true;
    if (req.response.headers.transfer_encoding) |transfer_encoding| {
        switch (transfer_encoding) {
            .chunked => {
                req.response.next_chunk_length = 0;
                req.response.state = .chunk_size;
            },
            // The header parser only ever stores `.chunked` here.
            .compress => unreachable,
            .deflate => unreachable,
            .gzip => unreachable,
        }
    } else if (req.response.headers.content_length) |content_length| {
        req.response.next_chunk_length = content_length;
        has_body = content_length != 0;
    } else {
        has_body = false;
    }

    if (!has_body) {
        req.response.done = true;
        req.client.release(req.connection);
        req.connection = undefined;
    }

    return i;
}
/// Errors of `waitForCompleteHead`: everything `readRaw` can return plus the
/// failures of receiving and parsing the response head.
pub const WaitForCompleteHeadError = ReadRawError || error{
UnexpectedEndOfStream,
HttpHeadersExceededSizeLimit,
ShortHttpStatusLine,
BadHttpVersion,
HttpHeaderContinuationsUnsupported,
HttpTransferEncodingUnsupported,
HttpConnectionHeaderUnsupported,
};
/// Reads from the connection until the whole response head has arrived and
/// been parsed. Body bytes that were read along with the head stay in the
/// request's read buffer. Returns immediately when the head is already
/// complete. Fails with `error.UnexpectedEndOfStream` if the stream ends
/// first.
pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void {
    if (req.response.state.isContent()) return;

    while (true) {
        const nread = try req.connection.data.read(req.read_buffer[0..]);
        const head_len = try checkForCompleteHead(req, req.read_buffer[0..nread]);

        if (head_len == 0) {
            if (nread == 0) return error.UnexpectedEndOfStream;
            continue;
        }

        // Everything past `head_len` is the start of the body.
        req.read_buffer_start = @intCast(ReadBufferIndex, head_len);
        req.read_buffer_len = @intCast(ReadBufferIndex, nread);
        return;
    }
}
/// Produces the next piece of the raw response body.
///
/// This one can return 0 without meaning EOF: it returns as soon as the
/// buffered input is used up, and `readRaw` calls it again. The end of the
/// body is signalled by `response.done`, at which point the connection has
/// been handed back to the client and `req.connection` is undefined.
///
/// When redirects are being followed and the response is a redirect, the body
/// is consumed and dropped: nothing is written to `buffer` and the result is
/// always 0.
///
/// Every byte that has been interpreted is also consumed from the read
/// buffer, so chunk-size lines, chunk separators and skipped bodies may be
/// split across any number of reads.
fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
    assert(req.response.state.isContent());
    if (req.response.done) return 0;

    if (req.read_buffer_start == req.read_buffer_len) {
        const nread = try req.connection.data.read(req.read_buffer[0..]);
        if (nread == 0) return error.UnexpectedEndOfStream;

        req.read_buffer_start = 0;
        req.read_buffer_len = @intCast(ReadBufferIndex, nread);
    }

    const discard_body = req.handle_redirects and req.response.headers.status.class() == .redirect;

    var out_index: usize = 0;
    while (true) {
        switch (req.response.state) {
            .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable,
            // Body delimited by Content-Length; next_chunk_length counts
            // the bytes that are still outstanding.
            .finished => {
                // TODO https://github.com/ziglang/zig/issues/14039
                const buf_avail = req.read_buffer_len - req.read_buffer_start;
                const data_avail = req.response.next_chunk_length;
                const out_avail = buffer.len;

                if (discard_body) {
                    const can_skip = @intCast(usize, @min(buf_avail, data_avail));
                    req.response.next_chunk_length -= can_skip;
                    // The skipped bytes must leave the buffer, otherwise
                    // they would be counted again on the next call and the
                    // rest of the body would never be read off the stream.
                    req.read_buffer_start += @intCast(ReadBufferIndex, can_skip);

                    if (req.response.next_chunk_length == 0) {
                        req.client.release(req.connection);
                        req.connection = undefined;
                        req.response.done = true;
                    }

                    return 0; // skip over as much data as possible
                }

                const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
                req.response.next_chunk_length -= can_read;

                mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]);
                req.read_buffer_start += @intCast(ReadBufferIndex, can_read);

                if (req.response.next_chunk_length == 0) {
                    req.client.release(req.connection);
                    req.connection = undefined;
                    req.response.done = true;
                }

                return can_read;
            },
            // Expecting the "\r\n" that follows the data of a chunk.
            .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) {
                0 => return out_index,
                1 => switch (req.read_buffer[req.read_buffer_start]) {
                    '\r' => {
                        // Consume the '\r'; the '\n' arrives with the next read.
                        req.read_buffer_start += 1;
                        req.response.state = .chunk_size_prefix_n;
                        return out_index;
                    },
                    else => {
                        req.response.state = .invalid;
                        return error.HttpHeadersInvalid;
                    },
                },
                else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) {
                    int16("\r\n") => {
                        req.read_buffer_start += 2;
                        req.response.state = .chunk_size;
                        continue;
                    },
                    else => {
                        req.response.state = .invalid;
                        return error.HttpHeadersInvalid;
                    },
                },
            },
            .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) {
                0 => return out_index,
                else => switch (req.read_buffer[req.read_buffer_start]) {
                    '\n' => {
                        req.read_buffer_start += 1;
                        req.response.state = .chunk_size;
                        continue;
                    },
                    else => {
                        req.response.state = .invalid;
                        return error.HttpHeadersInvalid;
                    },
                },
            },
            .chunk_size, .chunk_r => {
                const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]);
                // findChunkedLen keeps the digits it has seen in
                // next_chunk_length, so what it consumed must never be
                // presented to it again.
                req.read_buffer_start += @intCast(ReadBufferIndex, i);

                switch (req.response.state) {
                    .invalid => return error.HttpHeadersInvalid,
                    .chunk_data => {
                        if (req.response.next_chunk_length == 0) {
                            // Last chunk. It is followed by an optional
                            // trailer section and a final "\r\n". Only when
                            // that final "\r\n" is already buffered is the
                            // stream known to be clean; otherwise leftover
                            // bytes would corrupt the next response read on
                            // this connection, so it must not be pooled.
                            const rest = req.read_buffer[req.read_buffer_start..req.read_buffer_len];
                            if (mem.startsWith(u8, rest, "\r\n")) {
                                req.read_buffer_start += 2;
                            } else {
                                req.connection.data.closing = true;
                            }

                            req.response.done = true;
                            req.client.release(req.connection);
                            req.connection = undefined;

                            return out_index;
                        }

                        continue;
                    },
                    // The size line is incomplete; more input is needed.
                    .chunk_size, .chunk_r => return out_index,
                    else => unreachable,
                }
            },
            .chunk_data => {
                // TODO https://github.com/ziglang/zig/issues/14039
                const buf_avail = req.read_buffer_len - req.read_buffer_start;
                const data_avail = req.response.next_chunk_length;
                const out_avail = buffer.len - out_index;

                if (discard_body) {
                    const can_skip = @intCast(usize, @min(buf_avail, data_avail));
                    req.response.next_chunk_length -= can_skip;
                    req.read_buffer_start += @intCast(ReadBufferIndex, can_skip);

                    if (req.response.next_chunk_length == 0) {
                        // End of this chunk, not of the body: go on with
                        // the next chunk. The terminating zero-sized chunk
                        // is what finishes the response.
                        req.response.state = .chunk_size_prefix_r;
                        continue;
                    }

                    return 0; // skip over as much data as possible
                }

                const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
                req.response.next_chunk_length -= can_read;

                mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]);
                req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
                out_index += can_read;

                if (req.response.next_chunk_length == 0) {
                    req.response.state = .chunk_size_prefix_r;

                    continue;
                }

                return out_index;
            },
        }
    }
}
/// Errors of `read`: those of receiving the head, plus the errors of the two
/// decompressors.
pub const ReadError = DeflateDecompressor.Error || GzipDecompressor.Error || WaitForCompleteHeadError || error{
    BadHeader,
    InvalidCompression,
    StreamTooLong,
    InvalidWindowSize,
};

pub const Reader = std.io.Reader(*Request, ReadError, read);

/// Returns a `std.io.Reader` over the decoded response body.
pub fn reader(req: *Request) Reader {
    return Reader{ .context = req };
}
/// Reads decoded body bytes into `buffer` and returns how many were written;
/// 0 means the end of the body.
///
/// Waits for the response head first if it has not arrived yet. When
/// redirects are being followed and the response is a redirect, its body is
/// discarded, a request to the resolved Location is issued and `req.*` is
/// replaced by that new request; this repeats until a non-redirect response
/// arrives or `redirects_left` is exhausted (`error.TooManyHttpRedirects`).
/// A body compressed with deflate or gzip is decompressed transparently.
///
/// If following a redirect fails, `req` is left as the (fully consumed)
/// redirect response and still has to be passed to `deinit`.
pub fn read(req: *Request, buffer: []u8) ReadError!usize {
    while (true) {
        if (!req.response.state.isContent()) try req.waitForCompleteHead();

        const is_redirect = req.handle_redirects and req.response.headers.status.class() == .redirect;
        if (!is_redirect) break;

        // Drains the redirect body; no data is ever produced for it.
        assert(try req.readRaw(buffer) == 0);

        if (req.redirects_left == 0) return error.TooManyHttpRedirects;

        const location = req.response.headers.location orelse
            return error.HttpRedirectMissingLocation;

        // The new URI is built inside the request's arena, and that arena
        // is handed over to the new request below. This keeps alive
        // everything the resolved URI may point at: the copied location
        // string, and the memory of earlier redirects that `req.uri` may
        // reference. On any failure `req` still owns the arena, so a later
        // `deinit` frees it exactly once.
        const arena = req.arena.allocator();
        // `location` points into the header buffer, which `deinit` frees.
        const location_copy = try arena.dupe(u8, location);
        const new_url = Uri.parse(location_copy) catch try Uri.parseWithoutScheme(location_copy);
        const resolved_url = try req.uri.resolve(new_url, false, arena);

        var new_req = try req.client.request(resolved_url, req.headers, .{
            .max_redirects = req.redirects_left - 1,
            .header_strategy = if (req.response.header_bytes_owned) .{
                .dynamic = req.response.max_header_bytes,
            } else .{
                .static = req.response.header_bytes.unusedCapacitySlice(),
            },
        });

        // Move the arena into the new request. The new request's own arena
        // was only just initialized, and the old request gets a fresh one
        // for `deinit` to tear down.
        new_req.arena.deinit();
        new_req.arena = req.arena;
        req.arena = std.heap.ArenaAllocator.init(req.client.allocator);

        req.deinit();
        req.* = new_req;
    }

    if (req.response.compression == .none) {
        if (req.response.headers.transfer_compression) |compression| {
            switch (compression) {
                // No decompressor exists for `compress`; fail instead of
                // invoking undefined behavior should it ever show up here.
                .compress => return error.HttpTransferEncodingUnsupported,
                .deflate => req.response.compression = .{
                    .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }),
                },
                .gzip => req.response.compression = .{
                    .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }),
                },
                // The header parser never stores `.chunked` as a compression.
                .chunked => unreachable,
            }
        }
    }

    return switch (req.response.compression) {
        .deflate => |*deflate| try deflate.read(buffer),
        .gzip => |*gzip| try gzip.read(buffer),
        else => try req.readRaw(buffer),
    };
}
/// Calls `read` until `buffer` is full or the body ends. Returns the number
/// of bytes stored; a result smaller than `buffer.len` means the end of the
/// body was reached.
pub fn readAll(req: *Request, buffer: []u8) !usize {
    var remaining = buffer;
    while (remaining.len > 0) {
        const amt = try read(req, remaining);
        if (amt == 0) break;
        remaining = remaining[amt..];
    }
    return buffer.len - remaining.len;
}
/// Errors of `write`. `NotWriteable` is returned when the request was created
/// without a body (`transfer_encoding == .none`); it is listed here so that
/// `Writer` can actually be instantiated with `write`.
pub const WriteError = Connection.WriteError || error{ MessageTooLong, NotWriteable };

pub const Writer = std.io.Writer(*Request, WriteError, write);

/// Returns a `std.io.Writer` that sends request body bytes through `write`.
pub fn writer(req: *Request) Writer {
    return .{ .context = req };
}

/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
///
/// * chunked: `bytes` is sent as one chunk and `bytes.len` is returned. An
///   empty `bytes` sends nothing, because a zero-sized chunk is what
///   terminates a chunked body; use `finish` for that.
/// * content_length: fails with `error.MessageTooLong` if `bytes` exceeds
///   what is left of the announced length; otherwise writes some prefix of
///   `bytes`, subtracts it from the remaining length and returns its size.
/// * none: fails with `error.NotWriteable`.
pub fn write(req: *Request, bytes: []const u8) !usize {
    switch (req.headers.transfer_encoding) {
        .chunked => {
            if (bytes.len == 0) return 0;

            try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
            try req.connection.data.writeAll(bytes);
            try req.connection.data.writeAll("\r\n");

            return bytes.len;
        },
        .content_length => |*len| {
            if (len.* < bytes.len) return error.MessageTooLong;

            const amt = try req.connection.data.write(bytes);
            len.* -= amt;
            return amt;
        },
        .none => return error.NotWriteable,
    }
}
/// Finish the body of a request. This notifies the server that you have no more data to send.
///
/// * chunked: sends the terminating zero-sized chunk followed by the empty
///   line that ends the (empty) trailer section.
/// * content_length: sends nothing, but fails with
///   `error.MessageNotCompleted` if fewer bytes were written than announced.
/// * none: does nothing.
pub fn finish(req: *Request) !void {
    switch (req.headers.transfer_encoding) {
        // "0\r\n" alone is only the last-chunk line; the body is not
        // complete until the final "\r\n" has been sent as well.
        .chunked => try req.connection.data.writeAll("0\r\n\r\n"),
        .content_length => |len| if (len != 0) return error.MessageNotCompleted,
        .none => {},
    }
}
/// Reinterprets 2 bytes as a `u16` in native byte order. Only meant for
/// comparing byte sequences against each other, not for numeric decoding.
inline fn int16(array: *const [2]u8) u16 {
    const raw: [2]u8 = array.*;
    return @bitCast(u16, raw);
}

/// Reinterprets 4 bytes as a `u32` in native byte order; see `int16`.
inline fn int32(array: *const [4]u8) u32 {
    const raw: [4]u8 = array.*;
    return @bitCast(u32, raw);
}

/// Reinterprets 8 bytes as a `u64` in native byte order; see `int16`.
inline fn int64(array: *const [8]u8) u64 {
    const raw: [8]u8 = array.*;
    return @bitCast(u64, raw);
}
// References `Response` so that the tests declared inside it are analyzed.
// Skipped on WASI.
test {
const builtin = @import("builtin");
if (builtin.os.tag == .wasi) return error.SkipZigTest;
_ = Response;
}
};
/// Closes and frees every connection the client knows about, idle ones first
/// and checked-out ones after, then frees the certificate bundle.
/// The connection mutex is taken and never released; the client is
/// overwritten with `undefined` at the end and must not be used afterwards.
pub fn deinit(client: *Client) void {
    client.connection_mutex.lock();

    const queues = [_]*ConnectionPool{ &client.connection_pool, &client.connection_used };
    for (queues) |queue| {
        var cursor = queue.first;
        while (cursor) |node| {
            // Fetch the successor before the node is destroyed.
            cursor = node.next;

            node.data.close(client);
            client.allocator.destroy(node);
        }
    }

    client.ca_bundle.deinit(client.allocator);
    client.* = undefined;
}
pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);

/// Returns a connection to `host`:`port` speaking `protocol`, already
/// registered in the client's in-use list.
///
/// An idle pooled connection with the same host, port and protocol is reused
/// when there is one (most recently released first). Otherwise a new TCP
/// connection is opened and, for `.tls`, a TLS session is established on it.
///
/// On failure nothing is leaked: the stream, the copy of the host name, the
/// TLS client and the node itself are all released again.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
    { // Search through the connection pool for a potential connection.
        client.connection_mutex.lock();
        defer client.connection_mutex.unlock();

        var potential = client.connection_pool.last;
        while (potential) |node| {
            const same_host = mem.eql(u8, node.data.host, host);
            const same_port = node.data.port == port;
            const same_protocol = node.data.protocol == protocol;

            if (same_host and same_port and same_protocol) {
                client.acquire(node, true);
                return node;
            }

            potential = node.prev;
        }
    }

    const conn = try client.allocator.create(ConnectionNode);
    errdefer client.allocator.destroy(conn);

    // Acquire each resource on its own so every later failure can undo it.
    const stream = try net.tcpConnectToHost(client.allocator, host, port);
    errdefer stream.close();

    const host_copy = try client.allocator.dupe(u8, host);
    errdefer client.allocator.free(host_copy);

    conn.* = .{ .data = .{
        .stream = stream,
        .tls_client = undefined,
        .protocol = protocol,
        .host = host_copy,
        .port = port,
    } };

    switch (protocol) {
        .plain => {},
        .tls => {
            const tls_client = try client.allocator.create(std.crypto.tls.Client);
            errdefer client.allocator.destroy(tls_client);

            tls_client.* = try std.crypto.tls.Client.init(stream, client.ca_bundle, host);
            // This is appropriate for HTTPS because the HTTP headers contain
            // the content length which is used to detect truncation attacks.
            tls_client.allow_truncation_attacks = true;

            conn.data.tls_client = tls_client;
        },
    }

    {
        client.connection_mutex.lock();
        defer client.connection_mutex.unlock();

        client.connection_used.append(conn);
    }

    return conn;
}
/// Errors of `request`: connecting, writing the request head, and the errors
/// listed here.
// NOTE(review): the entries after `UriMissingHost` are presumably the errors of rescanning the certificate bundle — confirm against `Certificate.Bundle.rescan`.
pub const RequestError = ConnectError || Connection.WriteError || error{
UnsupportedUrlScheme,
UriMissingHost,
CertificateAuthorityBundleTooBig,
InvalidPadding,
MissingEndCertificateMarker,
Unseekable,
EndOfStream,
};
/// Connects to the server named by `uri` (reusing a pooled connection when
/// possible) and sends the request head described by `headers`.
///
/// Only the "http" and "https" schemes are supported. The port defaults to 80
/// and 443 respectively. For the first HTTPS request, and whenever
/// `next_https_rescan_certs` is set, the system root certificates are
/// rescanned first.
///
/// The returned request must be released with `Request.deinit`. `uri` and
/// `headers` are stored in it by value; the memory they point to must stay
/// valid for as long as the request is in use.
///
/// If anything fails after the connection was obtained, the connection is
/// closed and removed from the client again.
pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
    const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
        .plain
    else if (mem.eql(u8, uri.scheme, "https"))
        .tls
    else
        return error.UnsupportedUrlScheme;

    const port: u16 = uri.port orelse switch (protocol) {
        .plain => 80,
        .tls => 443,
    };

    const host = uri.host orelse return error.UriMissingHost;

    if (client.next_https_rescan_certs and protocol == .tls) {
        client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
        defer client.connection_mutex.unlock();

        // Checked again under the lock: another thread may have done the
        // rescan in the meantime.
        if (client.next_https_rescan_certs) {
            try client.ca_bundle.rescan(client.allocator);
            client.next_https_rescan_certs = false;
        }
    }

    var req: Request = .{
        .uri = uri,
        .client = client,
        .headers = headers,
        .connection = try client.connect(host, port, protocol),
        .redirects_left = options.max_redirects,
        .handle_redirects = options.handle_redirects,
        .compression_init = false,
        .response = switch (options.header_strategy) {
            .dynamic => |max| Request.Response.initDynamic(max),
            .static => |buf| Request.Response.initStatic(buf),
        },
        .arena = std.heap.ArenaAllocator.init(client.allocator),
    };
    // A connection on which a request head was only partially written can
    // not be reused, and nobody else would ever take it out of the in-use
    // list: close it.
    errdefer {
        req.connection.data.closing = true;
        client.release(req.connection);
    }

    {
        var buffered = std.io.bufferedWriter(req.connection.data.writer());
        const writer = buffered.writer();

        // An empty path is requested as "/": the request target must not be empty.
        const raw_path = if (uri.path.len == 0) "/" else uri.path;
        const escaped_path = try Uri.escapePath(client.allocator, raw_path);
        defer client.allocator.free(escaped_path);

        const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
        defer if (escaped_query) |q| client.allocator.free(q);

        // The fragment is deliberately not sent: it is client-side
        // information and not part of an HTTP request target.
        try writer.writeAll(@tagName(headers.method));
        try writer.writeByte(' ');
        try writer.writeAll(escaped_path);
        if (escaped_query) |q| {
            try writer.writeByte('?');
            try writer.writeAll(q);
        }
        try writer.writeByte(' ');
        try writer.writeAll(@tagName(headers.version));
        try writer.writeAll("\r\nHost: ");
        try writer.writeAll(host);
        try writer.writeAll("\r\nUser-Agent: ");
        try writer.writeAll(headers.user_agent);
        if (headers.connection == .close) {
            try writer.writeAll("\r\nConnection: close");
        } else {
            try writer.writeAll("\r\nConnection: keep-alive");
        }
        try writer.writeAll("\r\nAccept-Encoding: gzip, deflate");

        switch (headers.transfer_encoding) {
            .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
            .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
            .none => {},
        }

        for (headers.custom) |header| {
            try writer.writeAll("\r\n");
            try writer.writeAll(header.name);
            try writer.writeAll(": ");
            try writer.writeAll(header.value);
        }

        try writer.writeAll("\r\n\r\n");

        try buffered.flush();
    }

    return req;
}
// References `Request` so that the tests declared inside it are analyzed.
// Skipped on big-endian targets with the LLVM backend and on WASI.
test {
const builtin = @import("builtin");
const native_endian = comptime builtin.cpu.arch.endian();
if (builtin.zig_backend == .stage2_llvm and native_endian == .Big) {
// https://github.com/ziglang/zig/issues/13782
return error.SkipZigTest;
}
if (builtin.os.tag == .wasi) return error.SkipZigTest;
_ = Request;
}