From aecbfa3a1e9aa379368e6a9a999ca42fc4803f18 Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 30 Mar 2023 22:53:59 -0500 Subject: [PATCH] add buffering to connection instead of the http protocol, to allow passing through upgrades --- lib/std/http/Client.zig | 162 ++++++++++++++---- lib/std/http/Client/Response.zig | 276 ------------------------------- lib/std/http/Server.zig | 114 +++++++++++-- lib/std/http/protocol.zig | 118 ++++++------- 4 files changed, 282 insertions(+), 388 deletions(-) delete mode 100644 lib/std/http/Client/Response.zig diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1139d720ce..4b4e40133a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -32,7 +32,20 @@ pub const ConnectionPool = struct { is_tls: bool, }; - const Queue = std.TailQueue(Connection); + pub const StoredConnection = struct { + buffered: BufferedConnection, + host: []u8, + port: u16, + + closing: bool = false, + + pub fn deinit(self: *StoredConnection, client: *Client) void { + self.buffered.close(client); + client.allocator.free(self.host); + } + }; + + const Queue = std.TailQueue(StoredConnection); pub const Node = Queue.Node; mutex: std.Thread.Mutex = .{}, @@ -49,7 +62,7 @@ pub const ConnectionPool = struct { var next = pool.free.last; while (next) |node| : (next = node.prev) { - if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue; if (node.data.port != criteria.port) continue; if (mem.eql(u8, node.data.host, criteria.host)) continue; @@ -85,7 +98,7 @@ pub const ConnectionPool = struct { pool.used.remove(node); if (node.data.closing) { - node.data.close(client); + node.data.deinit(client); return client.allocator.destroy(node); } @@ -93,7 +106,7 @@ pub const ConnectionPool = struct { if (pool.free_len + 1 >= pool.free_size) { const popped = pool.free.popFirst() orelse unreachable; - popped.data.close(client); + popped.data.deinit(client); return client.allocator.destroy(popped); } @@ -118,7 +131,7 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.close(client); + node.data.deinit(client); } next = pool.used.first; @@ -126,7 +139,7 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.close(client); + node.data.deinit(client); } pool.* = undefined; @@ -140,13 +153,8 @@ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.Transfer pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. + tls_client: *std.crypto.tls.Client, protocol: Protocol, - host: []u8, - port: u16, - - // This connection has been part of a non keepalive request and cannot be added to the pool. 
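Reviewer sketch of what the pool hunks above mean for callers: the pool now stores a StoredConnection (a BufferedConnection plus host/port and the closing flag) instead of a bare Connection, so anything that used to touch node.data.protocol or node.data.stream has to reach through the new .buffered.conn layer. Hypothetical helper, illustrative only, not part of the patch:

    // Assumes `node` is a *ConnectionPool.Node handed out by the pool above.
    fn isReusableTlsConnection(node: *ConnectionPool.Node, host: []const u8, port: u16) bool {
        const stored = &node.data; // node.data is now a StoredConnection
        if (stored.closing) return false;
        if (stored.buffered.conn.protocol != .tls) return false; // raw Connection moved under .buffered.conn
        if (stored.port != port) return false;
        return mem.eql(u8, stored.host, host);
    }
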
- closing: bool = false, pub const Protocol = enum { plain, tls }; @@ -211,8 +219,89 @@ pub const Connection = struct { } conn.stream.close(); + } +}; - client.allocator.free(conn.host); +pub const BufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: Connection, + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(bconn: *BufferedConnection) ReadError!void { + if (bconn.end != bconn.start) return; + + const nread = try bconn.conn.read(bconn.buf[0..]); + if (nread == 0) return error.EndOfStream; + bconn.start = 0; + bconn.end = @truncate(u16, nread); + } + + pub fn peek(bconn: *BufferedConnection) []const u8 { + return bconn.buf[bconn.start..bconn.end]; + } + + pub fn clear(bconn: *BufferedConnection, num: u16) void { + bconn.start += num; + } + + pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = bconn.end - bconn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @truncate(u16, @min(available, left)); + + std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]); + out_index += can_read; + bconn.start += can_read; + + continue; + } + + if (left > bconn.buf.len) { + // skip the buffer if the output is large enough + return bconn.conn.read(buffer[out_index..]); + } + + try bconn.fill(); + } + + return out_index; + } + + pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { + return bconn.readAtLeast(buffer, 1); + } + + pub const ReadError = Connection.ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); + + pub fn reader(bconn: *BufferedConnection) Reader { + return Reader{ .context = bconn }; + } + + pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { + return bconn.conn.writeAll(buffer); + } + + pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { + return bconn.conn.write(buffer); + } + + pub const WriteError = Connection.WriteError; + pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); + + pub fn writer(bconn: *BufferedConnection) Writer { + return Writer{ .context = bconn }; + } + + pub fn close(bconn: *BufferedConnection, client: *const Client) void { + bconn.conn.close(client); } }; @@ -417,7 +506,7 @@ pub const Request = struct { req.* = undefined; } - pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); @@ -430,7 +519,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(req.connection.data.reader(), buf[index..], req.response.skip); + const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip); if (amt == 0 and req.response.parser.isComplete()) break; index += amt; } @@ -438,10 +527,17 @@ pub const Request = struct { return index; } - pub const WaitForCompleteHeadError = Connection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Response.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; + pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || 
Response.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; pub fn waitForCompleteHead(req: *Request) !void { - try req.response.parser.waitForCompleteHead(req.connection.data.reader(), req.client.allocator); + while (true) { + try req.connection.data.buffered.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); + req.connection.data.buffered.clear(@intCast(u16, nchecked)); + + if (req.response.parser.state.isContent()) break; + } req.response.headers = try Response.Headers.parse(req.response.parser.header_bytes.items); @@ -550,7 +646,7 @@ pub const Request = struct { return index; } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; pub const Writer = std.io.Writer(*Request, WriteError, write); @@ -562,16 +658,16 @@ pub const Request = struct { pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.headers.transfer_encoding) { .chunked => { - try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.data.writeAll(bytes); - try req.connection.data.writeAll("\r\n"); + try req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.data.conn.writeAll(bytes); + try req.connection.data.conn.writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.data.write(bytes); + const amt = try req.connection.data.conn.write(bytes); len.* -= amt; return amt; }, @@ -582,7 +678,7 @@ pub const Request = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) !void { switch (req.headers.transfer_encoding) { - .chunked => try req.connection.data.writeAll("0\r\n"), + .chunked => try req.connection.data.conn.writeAll("0\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } @@ -610,10 +706,14 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; + const stream = try net.tcpConnectToHost(client.allocator, host, port); + conn.data = .{ - .stream = try net.tcpConnectToHost(client.allocator, host, port), - .tls_client = undefined, - .protocol = protocol, + .buffered = .{ .conn = .{ + .stream = stream, + .tls_client = undefined, + .protocol = protocol, + } }, .host = try client.allocator.dupe(u8, host), .port = port, }; @@ -621,11 +721,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio switch (protocol) { .plain => {}, .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); + conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.buffered.conn.tls_client.* = try std.crypto.tls.Client.init(stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. 
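A short worked example of the chunked branch of Request.write() and finish() shown above, which now write through the relocated .conn handle. Sketch only, assuming a request head with Transfer-Encoding: chunked has already been sent; the comments show the bytes that end up on the wire:

    _ = try req.write("hello"); // sends "5\r\nhello\r\n" via req.connection.data.conn
    _ = try req.write("!");     // sends "1\r\n!\r\n"
    try req.finish();           // sends the last-chunk marker "0\r\n"
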
- conn.data.tls_client.allow_truncation_attacks = true; + conn.data.buffered.conn.tls_client.allow_truncation_attacks = true; }, } @@ -634,7 +734,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio return conn; } -pub const RequestError = ConnectError || Connection.WriteError || error{ +pub const RequestError = ConnectError || BufferedConnection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, @@ -708,7 +808,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt req.arena = std.heap.ArenaAllocator.init(client.allocator); { - var buffered = std.io.bufferedWriter(req.connection.data.writer()); + var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer()); const writer = buffered.writer(); const escaped_path = try Uri.escapePath(client.allocator, uri.path); diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig deleted file mode 100644 index d81decd9cc..0000000000 --- a/lib/std/http/Client/Response.zig +++ /dev/null @@ -1,276 +0,0 @@ -const std = @import("std"); -const http = std.http; -const mem = std.mem; -const testing = std.testing; -const assert = std.debug.assert; - -const protocol = @import("../protocol.zig"); -const Client = @import("../Client.zig"); -const Response = @This(); - -headers: Headers, -state: State, -header_bytes_owned: bool, -/// This could either be a fixed buffer provided by the API user or it -/// could be our own array list. -header_bytes: std.ArrayListUnmanaged(u8), -max_header_bytes: usize, -next_chunk_length: u64, -done: bool = false, - -compression: union(enum) { - deflate: Client.DeflateDecompressor, - gzip: Client.GzipDecompressor, - zstd: Client.ZstdDecompressor, - none: void, -} = .none, - -pub const Headers = struct { - status: http.Status, - version: http.Version, - location: ?[]const u8 = null, - content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, - connection: http.Connection = .close, - upgrade: ?[]const u8 = null, - - number_of_headers: usize = 0, - - pub fn parse(bytes: []const u8) !Headers { - var it = mem.split(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); - - const first_line = it.first(); - if (first_line.len < 12) - return error.ShortHttpStatusLine; - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); - - var headers: Headers = .{ - .version = version, - .status = status, - }; - - while (it.next()) |line| { - headers.number_of_headers += 1; - - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - var line_it = mem.split(u8, line, ": "); - const header_name = line_it.first(); - const header_value = line_it.rest(); - if (std.ascii.eqlIgnoreCase(header_name, "location")) { - if (headers.location != null) return error.HttpHeadersInvalid; - headers.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (headers.content_length != null) return error.HttpHeadersInvalid; - headers.content_length = try std.fmt.parseInt(u64, header_value, 10); - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = std.mem.splitBackwards(u8, header_value, ","); - - if (iter.next()) |first| { - const trimmed = std.mem.trim(u8, first, " "); - - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |second| { - if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - - const trimmed = std.mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - - const trimmed = std.mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection = .keep_alive; - } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection = .close; - } else { - return error.HttpConnectionHeaderUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { - headers.upgrade = header_value; - } - } - - return headers; - } - - test "parse headers" { - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - const parsed = try Headers.parse(example); - try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); - try 
testing.expectEqual(http.Status.moved_permanently, parsed.status); - try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse - return error.TestFailed); - try testing.expectEqual(@as(?u64, 220), parsed.content_length); - } - - test "header continuation" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeaderContinuationsUnsupported, - Headers.parse(example), - ); - } - - test "extra content length" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Length: 220\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "content-length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeadersInvalid, - Headers.parse(example), - ); - } -}; - -inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); -} - -pub const State = enum { - /// Begin header parsing states. - invalid, - start, - seen_r, - seen_rn, - seen_rnr, - finished, - /// Begin transfer-encoding: chunked parsing states. - chunk_size_prefix_r, - chunk_size_prefix_n, - chunk_size, - chunk_r, - chunk_data, - - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, - }; - } -}; - -pub fn initDynamic(max: usize) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{}, - .max_header_bytes = max, - .header_bytes_owned = true, - .next_chunk_length = undefined, - }; -} - -pub fn initStatic(buf: []u8) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, - .max_header_bytes = buf.len, - .header_bytes_owned = false, - .next_chunk_length = undefined, - }; -} - -fn parseInt3(nnn: @Vector(3, u8)) u10 { - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); -} - -test parseInt3 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000".*)); - try expectEqual(@as(u10, 418), parseInt3("418".*)); - try expectEqual(@as(u10, 999), parseInt3("999".*)); -} - -test "find headers end basic" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); - try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); - try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); -} - -test "find headers end vectorized" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n" ++ - "\r\ncontent"; - try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); -} - -test "find headers end bug" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - const example = - "HTTP/1.1 200 OK\r\n" ++ - "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++ - "content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++ - "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++ - "Content-Type: application/x-gzip\r\n" ++ - 
"ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++ - "Strict-Transport-Security: max-age=31536000\r\n" ++ - "Vary: Authorization,Accept-Encoding,Origin\r\n" ++ - "X-Content-Type-Options: nosniff\r\n" ++ - "X-Frame-Options: deny\r\n" ++ - "X-XSS-Protection: 1; mode=block\r\n" ++ - "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++ - "connection: close\r\n\r\n" ++ trail; - try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example)); -} diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index ba48fb3f83..b870d267f5 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -74,6 +74,89 @@ pub const Connection = struct { } }; +pub const BufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: Connection, + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(bconn: *BufferedConnection) ReadError!void { + if (bconn.end != bconn.start) return; + + const nread = try bconn.conn.read(bconn.buf[0..]); + if (nread == 0) return error.EndOfStream; + bconn.start = 0; + bconn.end = @truncate(u16, nread); + } + + pub fn peek(bconn: *BufferedConnection) []const u8 { + return bconn.buf[bconn.start..bconn.end]; + } + + pub fn clear(bconn: *BufferedConnection, num: u16) void { + bconn.start += num; + } + + pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = bconn.end - bconn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @truncate(u16, @min(available, left)); + + std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]); + out_index += can_read; + bconn.start += can_read; + + continue; + } + + if (left > bconn.buf.len) { + // skip the buffer if the output is large enough + return bconn.conn.read(buffer[out_index..]); + } + + try bconn.fill(); + } + + return out_index; + } + + pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { + return bconn.readAtLeast(buffer, 1); + } + + pub const ReadError = Connection.ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); + + pub fn reader(bconn: *BufferedConnection) Reader { + return Reader{ .context = bconn }; + } + + pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { + return bconn.conn.writeAll(buffer); + } + + pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { + return bconn.conn.write(buffer); + } + + pub const WriteError = Connection.WriteError; + pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); + + pub fn writer(bconn: *BufferedConnection) Writer { + return Writer{ .context = bconn }; + } + + pub fn close(bconn: *BufferedConnection) void { + bconn.conn.close(); + } +}; + pub const Request = struct { pub const Headers = struct { method: http.Method, @@ -222,7 +305,7 @@ pub const Response = struct { server: *Server, address: net.Address, - connection: Connection, + connection: BufferedConnection, headers: Headers = .{}, request: Request, @@ -237,10 +320,10 @@ pub const Response = struct { if (!res.request.parser.done) { // If the response wasn't fully read, then we need to close the connection. 
- res.connection.closing = true; + res.connection.conn.closing = true; } - if (res.connection.closing) { + if (res.connection.conn.closing) { res.connection.close(); if (res.request.parser.header_bytes_owned) { @@ -296,7 +379,7 @@ pub const Response = struct { try buffered.flush(); } - pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); @@ -309,7 +392,7 @@ pub const Response = struct { var index: usize = 0; while (index == 0) { - const amt = try res.request.parser.read(res.connection.reader(), buf[index..], false); + const amt = try res.request.parser.read(&res.connection, buf[index..], false); if (amt == 0 and res.request.parser.isComplete()) break; index += amt; } @@ -317,17 +400,24 @@ pub const Response = struct { return index; } - pub const WaitForCompleteHeadError = Connection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; + pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; pub fn waitForCompleteHead(res: *Response) !void { - try res.request.parser.waitForCompleteHead(res.connection.reader(), res.server.allocator); + while (true) { + try res.connection.fill(); + + const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + res.connection.clear(@intCast(u16, nchecked)); + + if (res.request.parser.state.isContent()) break; + } res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items); if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) { - res.connection.closing = false; + res.connection.conn.closing = false; } else { - res.connection.closing = true; + res.connection.conn.closing = true; } if (res.request.headers.transfer_encoding) |te| { @@ -388,7 +478,7 @@ pub const Response = struct { return index; } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; pub const Writer = std.io.Writer(*Response, WriteError, write); @@ -479,10 +569,10 @@ pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { res.* = .{ .server = server, .address = in.address, - .connection = .{ + .connection = .{ .conn = .{ .stream = in.stream, .protocol = .plain, - }, + } }, .request = .{ .parser = switch (options) { .dynamic => |max| proto.HeadersParser.initDynamic(max), diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 0006ba8df3..b898bf2a99 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -29,9 +29,6 @@ pub const State = enum { } }; -const read_buffer_size = 0x4000; -const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); - pub const HeadersParser = struct { state: State = .start, /// Wether or not `header_bytes` is allocated or was provided as a fixed buffer. 
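With the parser's internal read_buffer removed in the hunk above, unread bytes now stay in the connection's own buffer. On the server side that means that after the waitForCompleteHead() loop shown earlier, whatever the peer already sent past the final "\r\n\r\n" is still available via peek() rather than having been swallowed by the parser, which is what the commit subject means by passing through upgrades. A sketch, using only calls introduced in this patch (hypothetical helper name):

    fn firstBytesAfterHead(res: *Response) ![]const u8 {
        try res.waitForCompleteHead();
        return res.connection.peek(); // first body bytes, or the opening bytes of an upgraded protocol
    }
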
@@ -46,10 +43,6 @@ pub const HeadersParser = struct { /// A message is only done when the entire payload has been read done: bool = false, - read_buffer: [read_buffer_size]u8 = undefined, - read_buffer_start: ReadBufferIndex = 0, - read_buffer_len: ReadBufferIndex = 0, - pub fn initDynamic(max: usize) HeadersParser { return .{ .header_bytes = .{}, @@ -232,7 +225,7 @@ pub const HeadersParser = struct { } }, 4...vector_len - 1 => { - for (0..vector_len - 4) |i_usize| { + inline for (0..vector_len - 3) |i_usize| { const i = @truncate(u32, i_usize); const b32 = int32(chunk[i..][0..4]); @@ -246,6 +239,27 @@ pub const HeadersParser = struct { return index + i + 2; } } + + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => r.state = .seen_rnr, + else => {}, + } }, else => unreachable, } @@ -475,30 +489,6 @@ pub const HeadersParser = struct { return i; } - /// Set of errors that `waitForCompleteHead` can throw except any errors inherited by `reader` - pub const WaitForCompleteHeadError = CheckCompleteHeadError || error{UnexpectedEndOfStream}; - - /// Waits for the complete head to be available. This function will continue trying to read until the head is complete - /// or an error occurs. - pub fn waitForCompleteHead(r: *HeadersParser, reader: anytype, allocator: std.mem.Allocator) !void { - if (r.state.isContent()) return; - - while (true) { - if (r.read_buffer_start == r.read_buffer_len) { - const nread = try reader.read(r.read_buffer[0..]); - if (nread == 0) return error.UnexpectedEndOfStream; - - r.read_buffer_start = 0; - r.read_buffer_len = @intCast(ReadBufferIndex, nread); - } - - const amt = try r.checkCompleteHead(allocator, r.read_buffer[r.read_buffer_start..r.read_buffer_len]); - r.read_buffer_start += @intCast(ReadBufferIndex, amt); - - if (amt != 0) return; - } - } - pub const ReadError = error{ UnexpectedEndOfStream, HttpHeadersExceededSizeLimit, @@ -507,48 +497,40 @@ pub const HeadersParser = struct { /// Reads the body of the message into `buffer`. If `skip` is true, the buffer will be unused and the body will be /// skipped. Returns the number of bytes placed in the buffer. 
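The rewritten read() just below takes the buffered connection itself (as anytype) rather than a reader, because it needs fill/peek/clear in addition to read. A sketch of a caller-side loop against that contract, in the spirit of transferRead in Client.zig and Server.zig above; hypothetical helper, not part of the patch:

    fn readAllBody(parser: *HeadersParser, bconn: anytype, out: []u8) !usize {
        var n: usize = 0;
        while (n < out.len) {
            const amt = try parser.read(bconn, out[n..], false);
            if (amt == 0 and parser.isComplete()) break; // message body fully consumed
            n += amt;
        }
        return n;
    }
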
- pub fn read(r: *HeadersParser, reader: anytype, buffer: []u8, skip: bool) !usize { + pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize { assert(r.state.isContent()); if (r.done) return 0; - if (r.read_buffer_start == r.read_buffer_len) { - const nread = try reader.read(r.read_buffer[0..]); - if (nread == 0) return error.UnexpectedEndOfStream; - - r.read_buffer_start = 0; - r.read_buffer_len = @intCast(ReadBufferIndex, nread); - } - var out_index: usize = 0; while (true) { switch (r.state) { .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, .finished => { - const buf_avail = r.read_buffer_len - r.read_buffer_start; const data_avail = r.next_chunk_length; - const out_avail = buffer.len; - // TODO https://github.com/ziglang/zig/issues/14039 - const read_available = @intCast(usize, @min(buf_avail, data_avail)); if (skip) { - r.next_chunk_length -= read_available; - r.read_buffer_start += @intCast(ReadBufferIndex, read_available); - } else { - const can_read = @min(read_available, out_avail); - r.next_chunk_length -= can_read; + try bconn.fill(); - mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]); - r.read_buffer_start += @intCast(ReadBufferIndex, can_read); - out_index += can_read; + const nread = @min(bconn.peek().len, data_avail); + bconn.clear(@intCast(u16, nread)); + r.next_chunk_length -= nread; + + return 0; } - if (r.next_chunk_length == 0) r.done = true; + const out_avail = buffer.len; - return out_index; + const can_read = @min(data_avail, out_avail); + const nread = try bconn.read(buffer[0..can_read]); + r.next_chunk_length -= nread; + + return nread; }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - const i = r.findChunkedLen(r.read_buffer[r.read_buffer_start..r.read_buffer_len]); - r.read_buffer_start += @intCast(ReadBufferIndex, i); + try bconn.fill(); + + const i = r.findChunkedLen(bconn.peek()); + bconn.clear(@intCast(u16, i)); switch (r.state) { .invalid => return error.HttpChunkInvalid, @@ -565,22 +547,20 @@ pub const HeadersParser = struct { continue; }, .chunk_data => { - const buf_avail = r.read_buffer_len - r.read_buffer_start; const data_avail = r.next_chunk_length; - const out_avail = buffer.len; + const out_avail = buffer.len - out_index; - // TODO https://github.com/ziglang/zig/issues/14039 - const read_available = @intCast(usize, @min(buf_avail, data_avail)); if (skip) { - r.next_chunk_length -= read_available; - r.read_buffer_start += @intCast(ReadBufferIndex, read_available); - } else { - const can_read = @min(read_available, out_avail); - r.next_chunk_length -= can_read; + try bconn.fill(); - mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]); - r.read_buffer_start += @intCast(ReadBufferIndex, can_read); - out_index += can_read; + const nread = @min(bconn.peek().len, data_avail); + bconn.clear(@intCast(u16, nread)); + r.next_chunk_length -= nread; + } else { + const can_read = @min(data_avail, out_avail); + const nread = try bconn.read(buffer[out_index..][0..can_read]); + r.next_chunk_length -= nread; + out_index += nread; } if (r.next_chunk_length == 0) {