mirror of
https://github.com/ziglang/zig.git
synced 2026-02-13 04:48:20 +00:00
add buffering to connection instead of the http protocol, to allow passing through upgrades
This commit is contained in:
parent
08bdaf3bd6
commit
aecbfa3a1e
@ -32,7 +32,20 @@ pub const ConnectionPool = struct {
|
||||
is_tls: bool,
|
||||
};
|
||||
|
||||
const Queue = std.TailQueue(Connection);
|
||||
pub const StoredConnection = struct {
|
||||
buffered: BufferedConnection,
|
||||
host: []u8,
|
||||
port: u16,
|
||||
|
||||
closing: bool = false,
|
||||
|
||||
pub fn deinit(self: *StoredConnection, client: *Client) void {
|
||||
self.buffered.close(client);
|
||||
client.allocator.free(self.host);
|
||||
}
|
||||
};
|
||||
|
||||
const Queue = std.TailQueue(StoredConnection);
|
||||
pub const Node = Queue.Node;
|
||||
|
||||
mutex: std.Thread.Mutex = .{},
|
||||
@ -49,7 +62,7 @@ pub const ConnectionPool = struct {
|
||||
|
||||
var next = pool.free.last;
|
||||
while (next) |node| : (next = node.prev) {
|
||||
if ((node.data.protocol == .tls) != criteria.is_tls) continue;
|
||||
if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
|
||||
if (node.data.port != criteria.port) continue;
|
||||
if (mem.eql(u8, node.data.host, criteria.host)) continue;
|
||||
|
||||
@ -85,7 +98,7 @@ pub const ConnectionPool = struct {
|
||||
pool.used.remove(node);
|
||||
|
||||
if (node.data.closing) {
|
||||
node.data.close(client);
|
||||
node.data.deinit(client);
|
||||
|
||||
return client.allocator.destroy(node);
|
||||
}
|
||||
@ -93,7 +106,7 @@ pub const ConnectionPool = struct {
|
||||
if (pool.free_len + 1 >= pool.free_size) {
|
||||
const popped = pool.free.popFirst() orelse unreachable;
|
||||
|
||||
popped.data.close(client);
|
||||
popped.data.deinit(client);
|
||||
|
||||
return client.allocator.destroy(popped);
|
||||
}
|
||||
@ -118,7 +131,7 @@ pub const ConnectionPool = struct {
|
||||
defer client.allocator.destroy(node);
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
node.data.deinit(client);
|
||||
}
|
||||
|
||||
next = pool.used.first;
|
||||
@ -126,7 +139,7 @@ pub const ConnectionPool = struct {
|
||||
defer client.allocator.destroy(node);
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
node.data.deinit(client);
|
||||
}
|
||||
|
||||
pool.* = undefined;
|
||||
@ -140,13 +153,8 @@ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.Transfer
|
||||
pub const Connection = struct {
|
||||
stream: net.Stream,
|
||||
/// undefined unless protocol is tls.
|
||||
tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
|
||||
tls_client: *std.crypto.tls.Client,
|
||||
protocol: Protocol,
|
||||
host: []u8,
|
||||
port: u16,
|
||||
|
||||
// This connection has been part of a non keepalive request and cannot be added to the pool.
|
||||
closing: bool = false,
|
||||
|
||||
pub const Protocol = enum { plain, tls };
|
||||
|
||||
@ -211,8 +219,89 @@ pub const Connection = struct {
|
||||
}
|
||||
|
||||
conn.stream.close();
|
||||
}
|
||||
};
|
||||
|
||||
client.allocator.free(conn.host);
|
||||
pub const BufferedConnection = struct {
|
||||
pub const buffer_size = 0x2000;
|
||||
|
||||
conn: Connection,
|
||||
buf: [buffer_size]u8 = undefined,
|
||||
start: u16 = 0,
|
||||
end: u16 = 0,
|
||||
|
||||
pub fn fill(bconn: *BufferedConnection) ReadError!void {
|
||||
if (bconn.end != bconn.start) return;
|
||||
|
||||
const nread = try bconn.conn.read(bconn.buf[0..]);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
bconn.start = 0;
|
||||
bconn.end = @truncate(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(bconn: *BufferedConnection) []const u8 {
|
||||
return bconn.buf[bconn.start..bconn.end];
|
||||
}
|
||||
|
||||
pub fn clear(bconn: *BufferedConnection, num: u16) void {
|
||||
bconn.start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available = bconn.end - bconn.start;
|
||||
const left = buffer.len - out_index;
|
||||
|
||||
if (available > 0) {
|
||||
const can_read = @truncate(u16, @min(available, left));
|
||||
|
||||
std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
|
||||
out_index += can_read;
|
||||
bconn.start += can_read;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left > bconn.buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return bconn.conn.read(buffer[out_index..]);
|
||||
}
|
||||
|
||||
try bconn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return bconn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = Connection.ReadError || error{EndOfStream};
|
||||
pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
|
||||
|
||||
pub fn reader(bconn: *BufferedConnection) Reader {
|
||||
return Reader{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
|
||||
return bconn.conn.writeAll(buffer);
|
||||
}
|
||||
|
||||
pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
return bconn.conn.write(buffer);
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError;
|
||||
pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
|
||||
|
||||
pub fn writer(bconn: *BufferedConnection) Writer {
|
||||
return Writer{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn close(bconn: *BufferedConnection, client: *const Client) void {
|
||||
bconn.conn.close(client);
|
||||
}
|
||||
};
|
||||
|
||||
@ -417,7 +506,7 @@ pub const Request = struct {
|
||||
req.* = undefined;
|
||||
}
|
||||
|
||||
pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
|
||||
pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
|
||||
|
||||
pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
|
||||
|
||||
@ -430,7 +519,7 @@ pub const Request = struct {
|
||||
|
||||
var index: usize = 0;
|
||||
while (index == 0) {
|
||||
const amt = try req.response.parser.read(req.connection.data.reader(), buf[index..], req.response.skip);
|
||||
const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
|
||||
if (amt == 0 and req.response.parser.isComplete()) break;
|
||||
index += amt;
|
||||
}
|
||||
@ -438,10 +527,17 @@ pub const Request = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WaitForCompleteHeadError = Connection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Response.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
|
||||
pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Response.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
|
||||
|
||||
pub fn waitForCompleteHead(req: *Request) !void {
|
||||
try req.response.parser.waitForCompleteHead(req.connection.data.reader(), req.client.allocator);
|
||||
while (true) {
|
||||
try req.connection.data.buffered.fill();
|
||||
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
|
||||
req.connection.data.buffered.clear(@intCast(u16, nchecked));
|
||||
|
||||
if (req.response.parser.state.isContent()) break;
|
||||
}
|
||||
|
||||
req.response.headers = try Response.Headers.parse(req.response.parser.header_bytes.items);
|
||||
|
||||
@ -550,7 +646,7 @@ pub const Request = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
|
||||
pub const Writer = std.io.Writer(*Request, WriteError, write);
|
||||
|
||||
@ -562,16 +658,16 @@ pub const Request = struct {
|
||||
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
|
||||
switch (req.headers.transfer_encoding) {
|
||||
.chunked => {
|
||||
try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.data.writeAll(bytes);
|
||||
try req.connection.data.writeAll("\r\n");
|
||||
try req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.data.conn.writeAll(bytes);
|
||||
try req.connection.data.conn.writeAll("\r\n");
|
||||
|
||||
return bytes.len;
|
||||
},
|
||||
.content_length => |*len| {
|
||||
if (len.* < bytes.len) return error.MessageTooLong;
|
||||
|
||||
const amt = try req.connection.data.write(bytes);
|
||||
const amt = try req.connection.data.conn.write(bytes);
|
||||
len.* -= amt;
|
||||
return amt;
|
||||
},
|
||||
@ -582,7 +678,7 @@ pub const Request = struct {
|
||||
/// Finish the body of a request. This notifies the server that you have no more data to send.
|
||||
pub fn finish(req: *Request) !void {
|
||||
switch (req.headers.transfer_encoding) {
|
||||
.chunked => try req.connection.data.writeAll("0\r\n"),
|
||||
.chunked => try req.connection.data.conn.writeAll("0\r\n"),
|
||||
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
|
||||
.none => {},
|
||||
}
|
||||
@ -610,10 +706,14 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
|
||||
errdefer client.allocator.destroy(conn);
|
||||
conn.* = .{ .data = undefined };
|
||||
|
||||
const stream = try net.tcpConnectToHost(client.allocator, host, port);
|
||||
|
||||
conn.data = .{
|
||||
.stream = try net.tcpConnectToHost(client.allocator, host, port),
|
||||
.tls_client = undefined,
|
||||
.protocol = protocol,
|
||||
.buffered = .{ .conn = .{
|
||||
.stream = stream,
|
||||
.tls_client = undefined,
|
||||
.protocol = protocol,
|
||||
} },
|
||||
.host = try client.allocator.dupe(u8, host),
|
||||
.port = port,
|
||||
};
|
||||
@ -621,11 +721,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
|
||||
switch (protocol) {
|
||||
.plain => {},
|
||||
.tls => {
|
||||
conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
|
||||
conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
|
||||
conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
|
||||
conn.data.buffered.conn.tls_client.* = try std.crypto.tls.Client.init(stream, client.ca_bundle, host);
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
conn.data.tls_client.allow_truncation_attacks = true;
|
||||
conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
|
||||
},
|
||||
}
|
||||
|
||||
@ -634,7 +734,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
|
||||
return conn;
|
||||
}
|
||||
|
||||
pub const RequestError = ConnectError || Connection.WriteError || error{
|
||||
pub const RequestError = ConnectError || BufferedConnection.WriteError || error{
|
||||
UnsupportedUrlScheme,
|
||||
UriMissingHost,
|
||||
|
||||
@ -708,7 +808,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
|
||||
req.arena = std.heap.ArenaAllocator.init(client.allocator);
|
||||
|
||||
{
|
||||
var buffered = std.io.bufferedWriter(req.connection.data.writer());
|
||||
var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
|
||||
const writer = buffered.writer();
|
||||
|
||||
const escaped_path = try Uri.escapePath(client.allocator, uri.path);
|
||||
|
||||
@ -1,276 +0,0 @@
|
||||
const std = @import("std");
|
||||
const http = std.http;
|
||||
const mem = std.mem;
|
||||
const testing = std.testing;
|
||||
const assert = std.debug.assert;
|
||||
|
||||
const protocol = @import("../protocol.zig");
|
||||
const Client = @import("../Client.zig");
|
||||
const Response = @This();
|
||||
|
||||
headers: Headers,
|
||||
state: State,
|
||||
header_bytes_owned: bool,
|
||||
/// This could either be a fixed buffer provided by the API user or it
|
||||
/// could be our own array list.
|
||||
header_bytes: std.ArrayListUnmanaged(u8),
|
||||
max_header_bytes: usize,
|
||||
next_chunk_length: u64,
|
||||
done: bool = false,
|
||||
|
||||
compression: union(enum) {
|
||||
deflate: Client.DeflateDecompressor,
|
||||
gzip: Client.GzipDecompressor,
|
||||
zstd: Client.ZstdDecompressor,
|
||||
none: void,
|
||||
} = .none,
|
||||
|
||||
pub const Headers = struct {
|
||||
status: http.Status,
|
||||
version: http.Version,
|
||||
location: ?[]const u8 = null,
|
||||
content_length: ?u64 = null,
|
||||
transfer_encoding: ?http.TransferEncoding = null,
|
||||
transfer_compression: ?http.ContentEncoding = null,
|
||||
connection: http.Connection = .close,
|
||||
upgrade: ?[]const u8 = null,
|
||||
|
||||
number_of_headers: usize = 0,
|
||||
|
||||
pub fn parse(bytes: []const u8) !Headers {
|
||||
var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
|
||||
|
||||
const first_line = it.first();
|
||||
if (first_line.len < 12)
|
||||
return error.ShortHttpStatusLine;
|
||||
|
||||
const version: http.Version = switch (int64(first_line[0..8])) {
|
||||
int64("HTTP/1.0") => .@"HTTP/1.0",
|
||||
int64("HTTP/1.1") => .@"HTTP/1.1",
|
||||
else => return error.BadHttpVersion,
|
||||
};
|
||||
if (first_line[8] != ' ') return error.HttpHeadersInvalid;
|
||||
const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
|
||||
|
||||
var headers: Headers = .{
|
||||
.version = version,
|
||||
.status = status,
|
||||
};
|
||||
|
||||
while (it.next()) |line| {
|
||||
headers.number_of_headers += 1;
|
||||
|
||||
if (line.len == 0) return error.HttpHeadersInvalid;
|
||||
switch (line[0]) {
|
||||
' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
|
||||
else => {},
|
||||
}
|
||||
var line_it = mem.split(u8, line, ": ");
|
||||
const header_name = line_it.first();
|
||||
const header_value = line_it.rest();
|
||||
if (std.ascii.eqlIgnoreCase(header_name, "location")) {
|
||||
if (headers.location != null) return error.HttpHeadersInvalid;
|
||||
headers.location = header_value;
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
|
||||
if (headers.content_length != null) return error.HttpHeadersInvalid;
|
||||
headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
|
||||
// Transfer-Encoding: second, first
|
||||
// Transfer-Encoding: deflate, chunked
|
||||
var iter = std.mem.splitBackwards(u8, header_value, ",");
|
||||
|
||||
if (iter.next()) |first| {
|
||||
const trimmed = std.mem.trim(u8, first, " ");
|
||||
|
||||
if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
|
||||
if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
|
||||
headers.transfer_encoding = te;
|
||||
} else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
|
||||
if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
|
||||
headers.transfer_compression = ce;
|
||||
} else {
|
||||
return error.HttpTransferEncodingUnsupported;
|
||||
}
|
||||
}
|
||||
|
||||
if (iter.next()) |second| {
|
||||
if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
|
||||
|
||||
const trimmed = std.mem.trim(u8, second, " ");
|
||||
|
||||
if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
|
||||
headers.transfer_compression = ce;
|
||||
} else {
|
||||
return error.HttpTransferEncodingUnsupported;
|
||||
}
|
||||
}
|
||||
|
||||
if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
|
||||
if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
|
||||
|
||||
const trimmed = std.mem.trim(u8, header_value, " ");
|
||||
|
||||
if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
|
||||
headers.transfer_compression = ce;
|
||||
} else {
|
||||
return error.HttpTransferEncodingUnsupported;
|
||||
}
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
|
||||
if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
|
||||
headers.connection = .keep_alive;
|
||||
} else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
|
||||
headers.connection = .close;
|
||||
} else {
|
||||
return error.HttpConnectionHeaderUnsupported;
|
||||
}
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
|
||||
headers.upgrade = header_value;
|
||||
}
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
test "parse headers" {
|
||||
const example =
|
||||
"HTTP/1.1 301 Moved Permanently\r\n" ++
|
||||
"Location: https://www.example.com/\r\n" ++
|
||||
"Content-Type: text/html; charset=UTF-8\r\n" ++
|
||||
"Content-Length: 220\r\n\r\n";
|
||||
const parsed = try Headers.parse(example);
|
||||
try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
|
||||
try testing.expectEqual(http.Status.moved_permanently, parsed.status);
|
||||
try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse
|
||||
return error.TestFailed);
|
||||
try testing.expectEqual(@as(?u64, 220), parsed.content_length);
|
||||
}
|
||||
|
||||
test "header continuation" {
|
||||
const example =
|
||||
"HTTP/1.0 200 OK\r\n" ++
|
||||
"Content-Type: text/html;\r\n charset=UTF-8\r\n" ++
|
||||
"Content-Length: 220\r\n\r\n";
|
||||
try testing.expectError(
|
||||
error.HttpHeaderContinuationsUnsupported,
|
||||
Headers.parse(example),
|
||||
);
|
||||
}
|
||||
|
||||
test "extra content length" {
|
||||
const example =
|
||||
"HTTP/1.0 200 OK\r\n" ++
|
||||
"Content-Length: 220\r\n" ++
|
||||
"Content-Type: text/html; charset=UTF-8\r\n" ++
|
||||
"content-length: 220\r\n\r\n";
|
||||
try testing.expectError(
|
||||
error.HttpHeadersInvalid,
|
||||
Headers.parse(example),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
inline fn int64(array: *const [8]u8) u64 {
|
||||
return @bitCast(u64, array.*);
|
||||
}
|
||||
|
||||
pub const State = enum {
|
||||
/// Begin header parsing states.
|
||||
invalid,
|
||||
start,
|
||||
seen_r,
|
||||
seen_rn,
|
||||
seen_rnr,
|
||||
finished,
|
||||
/// Begin transfer-encoding: chunked parsing states.
|
||||
chunk_size_prefix_r,
|
||||
chunk_size_prefix_n,
|
||||
chunk_size,
|
||||
chunk_r,
|
||||
chunk_data,
|
||||
|
||||
pub fn isContent(self: State) bool {
|
||||
return switch (self) {
|
||||
.invalid, .start, .seen_r, .seen_rn, .seen_rnr => false,
|
||||
.finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub fn initDynamic(max: usize) Response {
|
||||
return .{
|
||||
.state = .start,
|
||||
.headers = undefined,
|
||||
.header_bytes = .{},
|
||||
.max_header_bytes = max,
|
||||
.header_bytes_owned = true,
|
||||
.next_chunk_length = undefined,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn initStatic(buf: []u8) Response {
|
||||
return .{
|
||||
.state = .start,
|
||||
.headers = undefined,
|
||||
.header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
|
||||
.max_header_bytes = buf.len,
|
||||
.header_bytes_owned = false,
|
||||
.next_chunk_length = undefined,
|
||||
};
|
||||
}
|
||||
|
||||
fn parseInt3(nnn: @Vector(3, u8)) u10 {
|
||||
const zero: @Vector(3, u8) = .{ '0', '0', '0' };
|
||||
const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
|
||||
return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
|
||||
}
|
||||
|
||||
test parseInt3 {
|
||||
const expectEqual = std.testing.expectEqual;
|
||||
try expectEqual(@as(u10, 0), parseInt3("000".*));
|
||||
try expectEqual(@as(u10, 418), parseInt3("418".*));
|
||||
try expectEqual(@as(u10, 999), parseInt3("999".*));
|
||||
}
|
||||
|
||||
test "find headers end basic" {
|
||||
var buffer: [1]u8 = undefined;
|
||||
var r = Response.initStatic(&buffer);
|
||||
try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4"));
|
||||
try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18"));
|
||||
try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah"));
|
||||
}
|
||||
|
||||
test "find headers end vectorized" {
|
||||
var buffer: [1]u8 = undefined;
|
||||
var r = Response.initStatic(&buffer);
|
||||
const example =
|
||||
"HTTP/1.1 301 Moved Permanently\r\n" ++
|
||||
"Location: https://www.example.com/\r\n" ++
|
||||
"Content-Type: text/html; charset=UTF-8\r\n" ++
|
||||
"Content-Length: 220\r\n" ++
|
||||
"\r\ncontent";
|
||||
try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example));
|
||||
}
|
||||
|
||||
test "find headers end bug" {
|
||||
var buffer: [1]u8 = undefined;
|
||||
var r = Response.initStatic(&buffer);
|
||||
const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
|
||||
const example =
|
||||
"HTTP/1.1 200 OK\r\n" ++
|
||||
"Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++
|
||||
"content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++
|
||||
"Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++
|
||||
"Content-Type: application/x-gzip\r\n" ++
|
||||
"ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++
|
||||
"Strict-Transport-Security: max-age=31536000\r\n" ++
|
||||
"Vary: Authorization,Accept-Encoding,Origin\r\n" ++
|
||||
"X-Content-Type-Options: nosniff\r\n" ++
|
||||
"X-Frame-Options: deny\r\n" ++
|
||||
"X-XSS-Protection: 1; mode=block\r\n" ++
|
||||
"Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++
|
||||
"Transfer-Encoding: chunked\r\n" ++
|
||||
"X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++
|
||||
"connection: close\r\n\r\n" ++ trail;
|
||||
try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example));
|
||||
}
|
||||
@ -74,6 +74,89 @@ pub const Connection = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const BufferedConnection = struct {
|
||||
pub const buffer_size = 0x2000;
|
||||
|
||||
conn: Connection,
|
||||
buf: [buffer_size]u8 = undefined,
|
||||
start: u16 = 0,
|
||||
end: u16 = 0,
|
||||
|
||||
pub fn fill(bconn: *BufferedConnection) ReadError!void {
|
||||
if (bconn.end != bconn.start) return;
|
||||
|
||||
const nread = try bconn.conn.read(bconn.buf[0..]);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
bconn.start = 0;
|
||||
bconn.end = @truncate(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(bconn: *BufferedConnection) []const u8 {
|
||||
return bconn.buf[bconn.start..bconn.end];
|
||||
}
|
||||
|
||||
pub fn clear(bconn: *BufferedConnection, num: u16) void {
|
||||
bconn.start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available = bconn.end - bconn.start;
|
||||
const left = buffer.len - out_index;
|
||||
|
||||
if (available > 0) {
|
||||
const can_read = @truncate(u16, @min(available, left));
|
||||
|
||||
std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
|
||||
out_index += can_read;
|
||||
bconn.start += can_read;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left > bconn.buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return bconn.conn.read(buffer[out_index..]);
|
||||
}
|
||||
|
||||
try bconn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return bconn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = Connection.ReadError || error{EndOfStream};
|
||||
pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
|
||||
|
||||
pub fn reader(bconn: *BufferedConnection) Reader {
|
||||
return Reader{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
|
||||
return bconn.conn.writeAll(buffer);
|
||||
}
|
||||
|
||||
pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
return bconn.conn.write(buffer);
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError;
|
||||
pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
|
||||
|
||||
pub fn writer(bconn: *BufferedConnection) Writer {
|
||||
return Writer{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn close(bconn: *BufferedConnection) void {
|
||||
bconn.conn.close();
|
||||
}
|
||||
};
|
||||
|
||||
pub const Request = struct {
|
||||
pub const Headers = struct {
|
||||
method: http.Method,
|
||||
@ -222,7 +305,7 @@ pub const Response = struct {
|
||||
|
||||
server: *Server,
|
||||
address: net.Address,
|
||||
connection: Connection,
|
||||
connection: BufferedConnection,
|
||||
|
||||
headers: Headers = .{},
|
||||
request: Request,
|
||||
@ -237,10 +320,10 @@ pub const Response = struct {
|
||||
|
||||
if (!res.request.parser.done) {
|
||||
// If the response wasn't fully read, then we need to close the connection.
|
||||
res.connection.closing = true;
|
||||
res.connection.conn.closing = true;
|
||||
}
|
||||
|
||||
if (res.connection.closing) {
|
||||
if (res.connection.conn.closing) {
|
||||
res.connection.close();
|
||||
|
||||
if (res.request.parser.header_bytes_owned) {
|
||||
@ -296,7 +379,7 @@ pub const Response = struct {
|
||||
try buffered.flush();
|
||||
}
|
||||
|
||||
pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
|
||||
pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
|
||||
|
||||
pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
|
||||
|
||||
@ -309,7 +392,7 @@ pub const Response = struct {
|
||||
|
||||
var index: usize = 0;
|
||||
while (index == 0) {
|
||||
const amt = try res.request.parser.read(res.connection.reader(), buf[index..], false);
|
||||
const amt = try res.request.parser.read(&res.connection, buf[index..], false);
|
||||
if (amt == 0 and res.request.parser.isComplete()) break;
|
||||
index += amt;
|
||||
}
|
||||
@ -317,17 +400,24 @@ pub const Response = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WaitForCompleteHeadError = Connection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
|
||||
pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
|
||||
|
||||
pub fn waitForCompleteHead(res: *Response) !void {
|
||||
try res.request.parser.waitForCompleteHead(res.connection.reader(), res.server.allocator);
|
||||
while (true) {
|
||||
try res.connection.fill();
|
||||
|
||||
const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
|
||||
res.connection.clear(@intCast(u16, nchecked));
|
||||
|
||||
if (res.request.parser.state.isContent()) break;
|
||||
}
|
||||
|
||||
res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items);
|
||||
|
||||
if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) {
|
||||
res.connection.closing = false;
|
||||
res.connection.conn.closing = false;
|
||||
} else {
|
||||
res.connection.closing = true;
|
||||
res.connection.conn.closing = true;
|
||||
}
|
||||
|
||||
if (res.request.headers.transfer_encoding) |te| {
|
||||
@ -388,7 +478,7 @@ pub const Response = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
|
||||
pub const Writer = std.io.Writer(*Response, WriteError, write);
|
||||
|
||||
@ -479,10 +569,10 @@ pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response {
|
||||
res.* = .{
|
||||
.server = server,
|
||||
.address = in.address,
|
||||
.connection = .{
|
||||
.connection = .{ .conn = .{
|
||||
.stream = in.stream,
|
||||
.protocol = .plain,
|
||||
},
|
||||
} },
|
||||
.request = .{
|
||||
.parser = switch (options) {
|
||||
.dynamic => |max| proto.HeadersParser.initDynamic(max),
|
||||
|
||||
@ -29,9 +29,6 @@ pub const State = enum {
|
||||
}
|
||||
};
|
||||
|
||||
const read_buffer_size = 0x4000;
|
||||
const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
|
||||
|
||||
pub const HeadersParser = struct {
|
||||
state: State = .start,
|
||||
/// Wether or not `header_bytes` is allocated or was provided as a fixed buffer.
|
||||
@ -46,10 +43,6 @@ pub const HeadersParser = struct {
|
||||
/// A message is only done when the entire payload has been read
|
||||
done: bool = false,
|
||||
|
||||
read_buffer: [read_buffer_size]u8 = undefined,
|
||||
read_buffer_start: ReadBufferIndex = 0,
|
||||
read_buffer_len: ReadBufferIndex = 0,
|
||||
|
||||
pub fn initDynamic(max: usize) HeadersParser {
|
||||
return .{
|
||||
.header_bytes = .{},
|
||||
@ -232,7 +225,7 @@ pub const HeadersParser = struct {
|
||||
}
|
||||
},
|
||||
4...vector_len - 1 => {
|
||||
for (0..vector_len - 4) |i_usize| {
|
||||
inline for (0..vector_len - 3) |i_usize| {
|
||||
const i = @truncate(u32, i_usize);
|
||||
|
||||
const b32 = int32(chunk[i..][0..4]);
|
||||
@ -246,6 +239,27 @@ pub const HeadersParser = struct {
|
||||
return index + i + 2;
|
||||
}
|
||||
}
|
||||
|
||||
const b24 = int24(chunk[vector_len - 3 ..][0..3]);
|
||||
const b16 = intShift(u16, b24);
|
||||
const b8 = intShift(u8, b24);
|
||||
|
||||
switch (b8) {
|
||||
'\r' => r.state = .seen_r,
|
||||
'\n' => r.state = .seen_n,
|
||||
else => {},
|
||||
}
|
||||
|
||||
switch (b16) {
|
||||
int16("\r\n") => r.state = .seen_rn,
|
||||
int16("\n\n") => r.state = .finished,
|
||||
else => {},
|
||||
}
|
||||
|
||||
switch (b24) {
|
||||
int24("\r\n\r") => r.state = .seen_rnr,
|
||||
else => {},
|
||||
}
|
||||
},
|
||||
else => unreachable,
|
||||
}
|
||||
@ -475,30 +489,6 @@ pub const HeadersParser = struct {
|
||||
return i;
|
||||
}
|
||||
|
||||
/// Set of errors that `waitForCompleteHead` can throw except any errors inherited by `reader`
|
||||
pub const WaitForCompleteHeadError = CheckCompleteHeadError || error{UnexpectedEndOfStream};
|
||||
|
||||
/// Waits for the complete head to be available. This function will continue trying to read until the head is complete
|
||||
/// or an error occurs.
|
||||
pub fn waitForCompleteHead(r: *HeadersParser, reader: anytype, allocator: std.mem.Allocator) !void {
|
||||
if (r.state.isContent()) return;
|
||||
|
||||
while (true) {
|
||||
if (r.read_buffer_start == r.read_buffer_len) {
|
||||
const nread = try reader.read(r.read_buffer[0..]);
|
||||
if (nread == 0) return error.UnexpectedEndOfStream;
|
||||
|
||||
r.read_buffer_start = 0;
|
||||
r.read_buffer_len = @intCast(ReadBufferIndex, nread);
|
||||
}
|
||||
|
||||
const amt = try r.checkCompleteHead(allocator, r.read_buffer[r.read_buffer_start..r.read_buffer_len]);
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, amt);
|
||||
|
||||
if (amt != 0) return;
|
||||
}
|
||||
}
|
||||
|
||||
pub const ReadError = error{
|
||||
UnexpectedEndOfStream,
|
||||
HttpHeadersExceededSizeLimit,
|
||||
@ -507,48 +497,40 @@ pub const HeadersParser = struct {
|
||||
|
||||
/// Reads the body of the message into `buffer`. If `skip` is true, the buffer will be unused and the body will be
|
||||
/// skipped. Returns the number of bytes placed in the buffer.
|
||||
pub fn read(r: *HeadersParser, reader: anytype, buffer: []u8, skip: bool) !usize {
|
||||
pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize {
|
||||
assert(r.state.isContent());
|
||||
if (r.done) return 0;
|
||||
|
||||
if (r.read_buffer_start == r.read_buffer_len) {
|
||||
const nread = try reader.read(r.read_buffer[0..]);
|
||||
if (nread == 0) return error.UnexpectedEndOfStream;
|
||||
|
||||
r.read_buffer_start = 0;
|
||||
r.read_buffer_len = @intCast(ReadBufferIndex, nread);
|
||||
}
|
||||
|
||||
var out_index: usize = 0;
|
||||
while (true) {
|
||||
switch (r.state) {
|
||||
.invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
|
||||
.finished => {
|
||||
const buf_avail = r.read_buffer_len - r.read_buffer_start;
|
||||
const data_avail = r.next_chunk_length;
|
||||
const out_avail = buffer.len;
|
||||
|
||||
// TODO https://github.com/ziglang/zig/issues/14039
|
||||
const read_available = @intCast(usize, @min(buf_avail, data_avail));
|
||||
if (skip) {
|
||||
r.next_chunk_length -= read_available;
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, read_available);
|
||||
} else {
|
||||
const can_read = @min(read_available, out_avail);
|
||||
r.next_chunk_length -= can_read;
|
||||
try bconn.fill();
|
||||
|
||||
mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]);
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, can_read);
|
||||
out_index += can_read;
|
||||
const nread = @min(bconn.peek().len, data_avail);
|
||||
bconn.clear(@intCast(u16, nread));
|
||||
r.next_chunk_length -= nread;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (r.next_chunk_length == 0) r.done = true;
|
||||
const out_avail = buffer.len;
|
||||
|
||||
return out_index;
|
||||
const can_read = @min(data_avail, out_avail);
|
||||
const nread = try bconn.read(buffer[0..can_read]);
|
||||
r.next_chunk_length -= nread;
|
||||
|
||||
return nread;
|
||||
},
|
||||
.chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
|
||||
const i = r.findChunkedLen(r.read_buffer[r.read_buffer_start..r.read_buffer_len]);
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, i);
|
||||
try bconn.fill();
|
||||
|
||||
const i = r.findChunkedLen(bconn.peek());
|
||||
bconn.clear(@intCast(u16, i));
|
||||
|
||||
switch (r.state) {
|
||||
.invalid => return error.HttpChunkInvalid,
|
||||
@ -565,22 +547,20 @@ pub const HeadersParser = struct {
|
||||
continue;
|
||||
},
|
||||
.chunk_data => {
|
||||
const buf_avail = r.read_buffer_len - r.read_buffer_start;
|
||||
const data_avail = r.next_chunk_length;
|
||||
const out_avail = buffer.len;
|
||||
const out_avail = buffer.len - out_index;
|
||||
|
||||
// TODO https://github.com/ziglang/zig/issues/14039
|
||||
const read_available = @intCast(usize, @min(buf_avail, data_avail));
|
||||
if (skip) {
|
||||
r.next_chunk_length -= read_available;
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, read_available);
|
||||
} else {
|
||||
const can_read = @min(read_available, out_avail);
|
||||
r.next_chunk_length -= can_read;
|
||||
try bconn.fill();
|
||||
|
||||
mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]);
|
||||
r.read_buffer_start += @intCast(ReadBufferIndex, can_read);
|
||||
out_index += can_read;
|
||||
const nread = @min(bconn.peek().len, data_avail);
|
||||
bconn.clear(@intCast(u16, nread));
|
||||
r.next_chunk_length -= nread;
|
||||
} else {
|
||||
const can_read = @min(data_avail, out_avail);
|
||||
const nread = try bconn.read(buffer[out_index..][0..can_read]);
|
||||
r.next_chunk_length -= nread;
|
||||
out_index += nread;
|
||||
}
|
||||
|
||||
if (r.next_chunk_length == 0) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user