From 6395ba852a88f0e0b2a2f0659f1daf9d08e90157 Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Tue, 20 Feb 2024 03:30:51 -0700
Subject: [PATCH] std.http.Server: rework the API entirely

Mainly, this removes the poorly named `wait`, `send`, and `finish`
functions, which all operated on the same "Response" object, an object
that was actually being used as the request.

Now, it looks like this:

1. std.net.Server.accept() gives you a std.net.Server.Connection
2. std.http.Server.init() with the connection
3. Server.receiveHead() gives you a Request
4. Request.reader() gives you a body reader
5. Request.respond() is a one-shot, or Request.respondStreaming() creates
   a Response
6. Response.writer() gives you a body writer
7. Response.end() finishes the response; Response.endChunked() allows
   passing response trailers.

In other words, the type system now guides the API user down the correct
path.

receiveHead allows extra bytes to be read into the read buffer, and then
reuses those bytes for the body, or for the next request upon connection
reuse.

respond(), the one-shot function, sends the entire response in one
syscall.

Streaming response bodies no longer wastefully wrap every call to write
in a chunk header and trailer; instead the HTTP chunk wrapper is only
sent when flushing. This means the user still controls when it happens,
and no unnecessary chunks are added.

Empirically, in my example project that uses this API, the usage code is
significantly less noisy, has less error handling while handling errors
more correctly, makes it more obvious what is happening, and is
syscall-optimal.

Additionally:

* Uncouple std.http.HeadParser from protocol.zig
* Delete std.http.Server.Connection; use std.net.Server.Connection
  instead.
  - The API user supplies the read buffer when initializing the
    http.Server, and it is used for the HTTP head as well as a buffer
    for reading the body into.
* Replace and document the State enum. No longer is there both "start"
  and "first".
---
 lib/std/http.zig                   |    2 +
 lib/std/http/HeadParser.zig        |  370 +++++++++
 lib/std/http/Server.zig            | 1245 ++++++++++++++++------
 lib/std/http/Server/Connection.zig |  119 ---
 lib/std/http/protocol.zig          |  364 +-------
 lib/std/http/test.zig              |   75 +-
 6 files changed, 1150 insertions(+), 1025 deletions(-)
 create mode 100644 lib/std/http/HeadParser.zig
 delete mode 100644 lib/std/http/Server/Connection.zig

diff --git a/lib/std/http.zig b/lib/std/http.zig
index bdeab598a6..613abc66b2 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -3,6 +3,7 @@ const std = @import("std.zig");
 pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
+pub const HeadParser = @import("http/HeadParser.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -311,5 +312,6 @@ test {
     _ = Method;
     _ = Server;
     _ = Status;
+    _ = HeadParser;
     _ = @import("http/test.zig");
 }
diff --git a/lib/std/http/HeadParser.zig b/lib/std/http/HeadParser.zig
new file mode 100644
index 0000000000..07c357731a
--- /dev/null
+++ b/lib/std/http/HeadParser.zig
@@ -0,0 +1,370 @@
+state: State = .start,
+
+pub const State = enum {
+    start,
+    seen_n,
+    seen_r,
+    seen_rn,
+    seen_rnr,
+    finished,
+};
+
+/// Returns the number of bytes consumed by headers. This is always less
+/// than or equal to `bytes.len`.
+///
+/// If the amount returned is less than `bytes.len`, the parser is in a
+/// content state and the first byte of content is located at
+/// `bytes[result]`.
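+///
+/// Illustrative sketch (assumes the complete head arrives in a single
+/// call; not part of the patch's tested API surface):
+///
+///     var p: HeadParser = .{};
+///     const n = p.feed("GET / HTTP/1.1\r\n\r\nbody");
+///     // Now p.state == .finished and n == 18: the head is 18 bytes,
+///     // and everything after it is body content.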
+pub fn feed(p: *HeadParser, bytes: []const u8) usize { + const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8); + const len: u32 = @intCast(bytes.len); + var index: u32 = 0; + + while (true) { + switch (p.state) { + .finished => return index, + .start => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + return index + 2; + }, + 3 => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + + return index + 3; + }, + 4...vector_len - 1 => { + const b32 = int32(bytes[index..][0..4]); + const b24 = intShift(u24, b32); + const b16 = intShift(u16, b32); + const b8 = intShift(u8, b32); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + + switch (b32) { + int32("\r\n\r\n") => p.state = .finished, + else => {}, + } + + index += 4; + continue; + }, + else => { + const chunk = bytes[index..][0..vector_len]; + const matches = if (use_vectors) matches: { + const Vector = @Vector(vector_len, u8); + // const BoolVector = @Vector(vector_len, bool); + const BitVector = @Vector(vector_len, u1); + const SizeVector = @Vector(vector_len, u8); + + const v: Vector = chunk.*; + const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r'))); + const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n'))); + const matches_or: SizeVector = matches_r | matches_n; + + break :matches @reduce(.Add, matches_or); + } else matches: { + var matches: u8 = 0; + for (chunk) |byte| switch (byte) { + '\r', '\n' => matches += 1, + else => {}, + }; + break :matches matches; + }; + switch (matches) { + 0 => {}, + 1 => switch (chunk[vector_len - 1]) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + }, + 2 => { + const b16 = int16(chunk[vector_len - 2 ..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + }, + 3 => { + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + }, + 4...vector_len => { + inline for (0..vector_len - 3) |i_usize| { + const i = @as(u32, @truncate(i_usize)); + + const b32 = int32(chunk[i..][0..4]); + 
const b16 = intShift(u16, b32); + + if (b32 == int32("\r\n\r\n")) { + p.state = .finished; + return index + i + 4; + } else if (b16 == int16("\n\n")) { + p.state = .finished; + return index + i + 2; + } + } + + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + }, + else => unreachable, + } + + index += vector_len; + continue; + }, + }, + .seen_n => switch (len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => p.state = .finished, + else => p.state = .start, + } + + index += 1; + continue; + }, + }, + .seen_r => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\n' => p.state = .seen_rn, + '\r' => p.state = .seen_r, + else => p.state = .start, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_rn, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\r") => p.state = .seen_rnr, + int16("\n\n") => p.state = .finished, + else => {}, + } + + return index + 2; + }, + else => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\n\r\n") => p.state = .finished, + else => {}, + } + + index += 3; + continue; + }, + }, + .seen_rn => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => p.state = .seen_rnr, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + return index + 1; + }, + else => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_rnr, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .finished, + int16("\n\n") => p.state = .finished, + else => {}, + } + + index += 2; + continue; + }, + }, + .seen_rnr => switch (len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => p.state = .finished, + else => p.state = .start, + } + + index += 1; + continue; + }, + }, + } + + return index; + } +} + +inline fn int16(array: *const [2]u8) u16 { + return @bitCast(array.*); +} + +inline fn int24(array: *const [3]u8) u24 { + return @bitCast(array.*); +} + +inline fn int32(array: *const [4]u8) u32 { + return @bitCast(array.*); +} + +inline fn intShift(comptime T: type, x: anytype) T { + switch (@import("builtin").cpu.arch.endian()) { + .little => return @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))), + .big => return @truncate(x), + } +} + +const HeadParser = @This(); +const std = @import("std"); +const use_vectors = builtin.zig_backend != .stage2_x86_64; +const builtin = @import("builtin"); + +test feed { + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello"; + + for (0..36) |i| { + var p: HeadParser = .{}; + try std.testing.expectEqual(i, p.feed(data[0..i])); + try 
std.testing.expectEqual(35 - i, p.feed(data[i..]));
+    }
+}
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
index ce770f14c4..73289d713f 100644
--- a/lib/std/http/Server.zig
+++ b/lib/std/http/Server.zig
@@ -1,613 +1,805 @@
-connection: Connection,
-/// This value is determined by Server when sending headers to the client, and
-/// then used to determine the return value of `reset`.
-connection_keep_alive: bool,
+//! Blocking HTTP server implementation.
 
-/// The HTTP request that this response is responding to.
-///
-/// This field is only valid after calling `wait`.
-request: Request,
+connection: net.Server.Connection,
+/// Keeps track of whether the Server is ready to accept a new request on the
+/// same connection, and makes invalid API usage cause assertion failures
+/// rather than HTTP protocol violations.
+state: State,
+/// User-provided buffer that must outlive this Server.
+/// Used to store the client's entire HTTP header.
+read_buffer: []u8,
+/// Amount of available data inside read_buffer.
+read_buffer_len: usize,
+/// Index into `read_buffer` of the first byte of the next HTTP request.
+next_request_start: usize,
 
-state: State = .first,
+pub const State = enum {
+    /// The connection is available to be used for the first time, or reused.
+    ready,
+    /// An error occurred in `receiveHead`.
+    receiving_head,
+    /// A Request object has been obtained and from there a Response can be
+    /// opened.
+    received_head,
+    /// The client is uploading something to this Server.
+    receiving_body,
+    /// The connection is eligible for another HTTP request, however the client
+    /// and server did not negotiate connection: keep-alive.
+    closing,
+};
 
 /// Initialize an HTTP server that can respond to multiple requests on the same
 /// connection.
-/// The returned `Server` is ready for `reset` or `wait` to be called.
-pub fn init(connection: std.net.Server.Connection, options: Server.Request.InitOptions) Server {
+/// The returned `Server` is ready for `receiveHead` to be called.
+pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
     return .{
-        .connection = .{
-            .stream = connection.stream,
-            .read_buf = undefined,
-            .read_start = 0,
-            .read_end = 0,
-        },
-        .connection_keep_alive = false,
-        .request = Server.Request.init(options),
+        .connection = connection,
+        .state = .ready,
+        .read_buffer = read_buffer,
+        .read_buffer_len = 0,
+        .next_request_start = 0,
     };
 }
 
-pub const State = enum {
-    first,
-    start,
-    waited,
-    responded,
-    finished,
+pub const ReceiveHeadError = error{
+    /// Client sent too many bytes of HTTP headers.
+    /// The HTTP specification suggests to respond with a 431 status code
+    /// before closing the connection.
+    HttpHeadersOversize,
+    /// Client sent headers that did not conform to the HTTP protocol.
+    HttpHeadersInvalid,
+    /// A low level I/O error occurred trying to read the headers.
+    HttpHeadersUnreadable,
+    /// Partial HTTP request was received but the connection was closed
+    /// before fully receiving the headers.
+    HttpRequestTruncated,
+    /// The client sent 0 bytes of headers before closing the connection.
+    /// In other words, a keep-alive connection was finally closed.
+    HttpConnectionClosing,
 };
 
-pub const ResetState = enum { reset, closing };
+/// The header bytes reference the read buffer that Server was initialized with
+/// and remain alive until the next call to receiveHead.
+pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
+    assert(s.state == .ready);
+    s.state = .received_head;
+    errdefer s.state = .receiving_head;
 
-pub const Connection = @import("Server/Connection.zig");
-
-/// The mode of transport for responses.
-pub const ResponseTransfer = union(enum) {
-    content_length: u64,
-    chunked: void,
-    none: void,
-};
-
-/// The decompressor for request messages.
-pub const Compression = union(enum) {
-    pub const DeflateDecompressor = std.compress.zlib.Decompressor(Server.TransferReader);
-    pub const GzipDecompressor = std.compress.gzip.Decompressor(Server.TransferReader);
-    // https://github.com/ziglang/zig/issues/18937
-    //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Server.TransferReader, .{});
-
-    deflate: DeflateDecompressor,
-    gzip: GzipDecompressor,
-    // https://github.com/ziglang/zig/issues/18937
-    //zstd: ZstdDecompressor,
-    none: void,
-};
-
-/// A HTTP request originating from a client.
-pub const Request = struct {
-    method: http.Method,
-    target: []const u8,
-    version: http.Version,
-    expect: ?[]const u8,
-    content_type: ?[]const u8,
-    content_length: ?u64,
-    transfer_encoding: http.TransferEncoding,
-    transfer_compression: http.ContentEncoding,
-    keep_alive: bool,
-    parser: proto.HeadersParser,
-    compression: Compression,
-
-    pub const InitOptions = struct {
-        /// Externally-owned memory used to store the client's entire HTTP header.
-        /// `error.HttpHeadersOversize` is returned from read() when a
-        /// client sends too many bytes of HTTP headers.
-        client_header_buffer: []u8,
-    };
-
-    pub fn init(options: InitOptions) Request {
-        return .{
-            .method = undefined,
-            .target = undefined,
-            .version = undefined,
-            .expect = null,
-            .content_type = null,
-            .content_length = null,
-            .transfer_encoding = .none,
-            .transfer_compression = .identity,
-            .keep_alive = false,
-            .parser = proto.HeadersParser.init(options.client_header_buffer),
-            .compression = .none,
-        };
+    // In case of a reused connection, move the next request's bytes to the
+    // beginning of the buffer.
+    if (s.next_request_start > 0) {
+        if (s.read_buffer_len > s.next_request_start) {
+            const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
+            const dest = s.read_buffer[0..leftover.len];
+            if (leftover.len <= s.next_request_start) {
+                @memcpy(dest, leftover);
+            } else {
+                // The regions overlap with dest before source, so the copy
+                // must proceed front to back.
+                mem.copyForwards(u8, dest, leftover);
+            }
+            s.read_buffer_len = leftover.len;
+        } else {
+            s.read_buffer_len = 0;
+        }
+        s.next_request_start = 0;
     }
 
-    pub const ParseError = error{
-        UnknownHttpMethod,
-        HttpHeadersInvalid,
-        HttpHeaderContinuationsUnsupported,
-        HttpTransferEncodingUnsupported,
-        HttpConnectionHeaderUnsupported,
-        InvalidContentLength,
-        CompressionUnsupported,
+    var hp: http.HeadParser = .{};
+
+    // A reused connection may already have buffered bytes from the previous
+    // request; parse those before reading from the stream.
+    if (s.read_buffer_len > 0) {
+        const bytes = s.read_buffer[0..s.read_buffer_len];
+        const end = hp.feed(bytes);
+        if (hp.state == .finished) return .{
+            .server = s,
+            .head_end = end,
+            .head = Request.Head.parse(s.read_buffer[0..end]) catch
+                return error.HttpHeadersInvalid,
+            .reader_state = undefined,
+        };
+    }
+
+    while (true) {
+        const buf = s.read_buffer[s.read_buffer_len..];
+        if (buf.len == 0)
+            return error.HttpHeadersOversize;
+        const read_n = s.connection.stream.read(buf) catch
+            return error.HttpHeadersUnreadable;
+        if (read_n == 0) {
+            if (s.read_buffer_len > 0) {
+                return error.HttpRequestTruncated;
+            } else {
+                return error.HttpConnectionClosing;
+            }
+        }
+        s.read_buffer_len += read_n;
+        const bytes = buf[0..read_n];
+        const end = hp.feed(bytes);
+        if (hp.state == .finished) {
+            // `end` is relative to this read; convert it to an offset into
+            // `read_buffer`.
+            const head_end = s.read_buffer_len - read_n + end;
+            return .{
+                .server = s,
+                .head_end = head_end,
+                .head = Request.Head.parse(s.read_buffer[0..head_end]) catch
+                    return error.HttpHeadersInvalid,
+                .reader_state = undefined,
+            };
+        }
+    }
+}
+
+pub const Request = struct {
+    server: *Server,
+    /// Index into Server's read_buffer.
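+    /// `read_buffer[0..head_end]` holds the request head, including its
+    /// terminating empty line.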
+ head_end: usize, + head: Head, + reader_state: union { + remaining_content_length: u64, + }, + + pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); + pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, }; - pub fn parse(req: *Request, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); + pub const Head = struct { + method: http.Method, + target: []const u8, + version: http.Version, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, + compression: Compression, - const first_line = it.next().?; - if (first_line.len < 10) - return error.HttpHeadersInvalid; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; - - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, + pub const ParseError = error{ + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + MissingFinalNewline, }; - const target = first_line[method_end + 1 .. 
version_start]; + pub fn parse(bytes: []const u8) ParseError!Head { + var it = mem.splitSequence(u8, bytes, "\r\n"); - req.method = method; - req.target = target; - req.version = version; + const first_line = it.next().?; + if (first_line.len < 10) + return error.HttpHeadersInvalid; - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; - var line_it = mem.splitSequence(u8, line, ": "); - const header_name = line_it.next().?; - const header_value = line_it.rest(); - if (header_value.len == 0) return error.HttpHeadersInvalid; + const method_str = first_line[0..method_end]; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - req.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { - req.expect = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - req.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (req.content_length != null) return error.HttpHeadersInvalid; - req.content_length = std.fmt.parseInt(u64, header_value, 10) catch - return error.InvalidContentLength; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; - const trimmed = mem.trim(u8, header_value, " "); + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - req.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); + const target = first_line[method_end + 1 .. 
version_start]; - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); + var head: Head = .{ + .method = method, + .target = target, + .version = version, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = false, + .compression = .none, + }; - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (req.transfer_encoding != .none) - return error.HttpHeadersInvalid; // we already have a transfer encoding - req.transfer_encoding = transfer; - - next = iter.next(); + while (it.next()) |line| { + if (line.len == 0) return head; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, } - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); + var line_it = mem.splitSequence(u8, line, ": "); + const header_name = line_it.next().?; + const header_value = line_it.rest(); + if (header_value.len == 0) return error.HttpHeadersInvalid; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (req.transfer_compression != .identity) - return error.HttpHeadersInvalid; // double compression is not supported - req.transfer_compression = transfer; + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { + head.expect = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + head.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (head.content_length != null) return error.HttpHeadersInvalid; + head.content_length = std.fmt.parseInt(u64, header_value, 10) catch + return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; } + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (head.transfer_encoding != .none) + return error.HttpHeadersInvalid; // we already have a transfer encoding + head.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (head.transfer_compression != .identity) + return error.HttpHeadersInvalid; // double compression is not supported + head.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; } + return error.MissingFinalNewline; } - return error.HttpHeadersInvalid; // missing empty line - } - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - 
} -}; - -/// Reset this response to its initial state. This must be called before -/// handling a second request on the same connection. -pub fn reset(res: *Server) ResetState { - if (res.state == .first) { - res.state = .start; - return .reset; - } - - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection_keep_alive = false; - return .closing; - } - - res.state = .start; - res.request = Request.init(.{ - .client_header_buffer = res.request.parser.header_bytes_buffer, - }); - - return if (res.connection_keep_alive) .reset else .closing; -} - -pub const SendAllError = std.net.Stream.WriteError; - -pub const SendOptions = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - keep_alive: bool = true, - extra_headers: []const http.Header = &.{}, - content: []const u8, -}; - -/// Send an entire HTTP response to the client, including headers and body. -/// Automatically handles HEAD requests by omitting the body. -/// Uses the "content-length" header. -/// Asserts status is not `continue`. -/// Asserts there are at most 25 extra_headers. -pub fn sendAll(s: *Server, options: SendOptions) SendAllError!void { - const max_extra_headers = 25; - assert(options.status != .@"continue"); - assert(options.extra_headers.len <= max_extra_headers); - - switch (s.state) { - .waited => s.state = .finished, - .first => unreachable, // Call reset() first. - .start => unreachable, // Call wait() first. - .responded => unreachable, // Cannot mix sendAll() with send(). - .finished => unreachable, // Call reset() first. - } - - s.connection_keep_alive = options.keep_alive and s.request.keep_alive; - const keep_alive_line = if (s.connection_keep_alive) - "connection: keep-alive\r\n" - else - ""; - const phrase = options.reason orelse options.status.phrase() orelse ""; - - var first_buffer: [500]u8 = undefined; - const first_bytes = std.fmt.bufPrint( - &first_buffer, - "{s} {d} {s}\r\n{s}content-length: {d}\r\n", - .{ - @tagName(options.version), - @intFromEnum(options.status), - phrase, - keep_alive_line, - options.content.len, - }, - ) catch unreachable; - - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .iov_base = first_bytes.ptr, - .iov_len = first_bytes.len, + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } }; - iovecs_len += 1; - for (options.extra_headers) |header| { + pub const RespondOptions = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + keep_alive: bool = true, + extra_headers: []const http.Header = &.{}, + }; + + /// Send an entire HTTP response to the client, including headers and body. + /// + /// Automatically handles HEAD requests by omitting the body. + /// Uses the "content-length" header unless `content` is empty in which + /// case it omits the content-length header. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// Asserts status is not `continue`. + /// Asserts there are at most 25 extra_headers. 
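+    ///
+    /// Hedged usage sketch (the header values are illustrative):
+    ///
+    ///     try request.respond("hello http server\n", .{
+    ///         .extra_headers = &.{
+    ///             .{ .name = "content-type", .value = "text/plain" },
+    ///         },
+    ///     });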
+    pub fn respond(
+        request: *Request,
+        content: []const u8,
+        options: RespondOptions,
+    ) Response.WriteError!void {
+        const max_extra_headers = 25;
+        assert(options.status != .@"continue");
+        assert(options.extra_headers.len <= max_extra_headers);
+
+        const keep_alive = request.discardBody(options.keep_alive);
+
+        const phrase = options.reason orelse options.status.phrase() orelse "";
+
+        var first_buffer: [500]u8 = undefined;
+        var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer);
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
+            @tagName(options.version), @intFromEnum(options.status), phrase,
+        }) catch unreachable; // first_buffer is large enough for any status line
+        if (keep_alive)
+            h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
+        if (content.len > 0)
+            h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch
+                unreachable;
+
+        var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
+        var iovecs_len: usize = 0;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = h.items.ptr,
+            .iov_len = h.items.len,
+        };
+        iovecs_len += 1;
+
+        for (options.extra_headers) |header| {
+            iovecs[iovecs_len] = .{
+                .iov_base = header.name.ptr,
+                .iov_len = header.name.len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = ": ",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = header.value.ptr,
+                .iov_len = header.value.len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        }
+
+        iovecs[iovecs_len] = .{
+            .iov_base = "\r\n",
+            .iov_len = 2,
+        };
+        iovecs_len += 1;
+
+        if (request.head.method != .HEAD and content.len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = content.ptr,
+                .iov_len = content.len,
+            };
+            iovecs_len += 1;
+        }
+
+        try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]);
+    }
-        iovecs[iovecs_len] = .{
-            .iov_base = header.name.ptr,
-            .iov_len = header.name.len,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = ": ",
-            .iov_len = 2,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = header.value.ptr,
-            .iov_len = header.value.len,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = "\r\n",
-            .iov_len = 2,
-        };
-        iovecs_len += 1;
-    }
-
-    iovecs[iovecs_len] = .{
-        .iov_base = "\r\n",
-        .iov_len = 2,
-    };
-    iovecs_len += 1;
-
-    if (s.request.method != .HEAD) {
-        iovecs[iovecs_len] = .{
-            .iov_base = options.content.ptr,
-            .iov_len = options.content.len,
-        };
-        iovecs_len += 1;
-    }
-
-    return s.connection.stream.writevAll(iovecs[0..iovecs_len]);
-}
+
+    pub const RespondStreamingOptions = struct {
+        /// An externally managed slice of memory used to batch bytes before
+        /// sending. `respondStreaming` asserts this is large enough to store
+        /// the full HTTP response head.
+        ///
+        /// Must outlive the returned Response.
+        send_buffer: []u8,
+        /// If provided, the response will use the content-length header;
+        /// otherwise it will use transfer-encoding: chunked.
+        content_length: ?u64 = null,
+        /// Options that are shared with the `respond` method.
+        respond_options: RespondOptions = .{},
+    };
+
+    /// The header is buffered but not sent until Response.flush is called.
+    ///
+    /// If the request contains a body and the connection is to be reused,
+    /// discards the request body, leaving the Server in the `ready` state. If
+    /// this discarding fails, the connection is marked as not to be reused and
+    /// no error is surfaced.
+    ///
+    /// HEAD requests are handled transparently by setting a flag on the
+    /// returned Response to omit the body. However it may be worth noticing
+    /// that flag and skipping any expensive work that would otherwise need to
+    /// be done to satisfy the request.
+    ///
+    /// Asserts `send_buffer` is large enough to store the entire response header.
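+    /// (As a rough, illustrative guide, a few kilobytes, e.g. a
+    /// stack-allocated `[4096]u8`, comfortably fits a typical head.)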
+    /// Asserts status is not `continue`.
+    pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response {
+        const o = options.respond_options;
+        assert(o.status != .@"continue");
+
+        const keep_alive = request.discardBody(o.keep_alive);
+        const phrase = o.reason orelse o.status.phrase() orelse "";
+
+        var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
+
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
+            @tagName(o.version), @intFromEnum(o.status), phrase,
+        }) catch unreachable; // asserted: send_buffer fits the response head
+        if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
+
+        if (options.content_length) |len| {
+            h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
+        } else {
+            h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
+        }
+
+        for (o.extra_headers) |header| {
+            h.appendSliceAssumeCapacity(header.name);
+            h.appendSliceAssumeCapacity(": ");
+            h.appendSliceAssumeCapacity(header.value);
+            h.appendSliceAssumeCapacity("\r\n");
+        }
+
+        h.appendSliceAssumeCapacity("\r\n");
+
+        return .{
+            .stream = request.server.connection.stream,
+            .send_buffer = options.send_buffer,
+            .send_buffer_start = 0,
+            .send_buffer_end = h.items.len,
+            .content_length = options.content_length,
+            .elide_body = request.head.method == .HEAD,
+            .chunk_len = 0,
+        };
+    }
+
+    pub const ReadError = net.Stream.ReadError;
+
+    fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
+        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
+        const s = request.server;
+        assert(s.state == .receiving_body);
+
+        const remaining_content_length = &request.reader_state.remaining_content_length;
+
+        if (remaining_content_length.* == 0) {
+            s.state = .ready;
+            return 0;
+        }
+
+        const head_end = request.head_end;
+        if (s.next_request_start == s.read_buffer_len) {
+            // All buffered bytes are consumed; read more from the stream,
+            // reusing the part of the buffer after the head.
+            s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
+            s.next_request_start = head_end;
+        }
+
+        const available_buf = s.read_buffer[s.next_request_start..s.read_buffer_len];
+        const len: usize = @min(remaining_content_length.*, available_buf.len, buffer.len);
+        @memcpy(buffer[0..len], available_buf[0..len]);
+        s.next_request_start += len;
+        remaining_content_length.* -= len;
+        if (remaining_content_length.* == 0)
+            s.state = .ready;
+        return len;
+    }
+
+    fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
+        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
+        const s = request.server;
+        assert(s.state == .receiving_body);
+        _ = buffer;
+        @panic("TODO");
+    }
+
+    pub const ReadAllError = ReadError || error{HttpBodyOversize};
+
+    pub fn reader(request: *Request) std.io.AnyReader {
+        const s = request.server;
+        assert(s.state == .received_head);
+        s.state = .receiving_body;
+        // Body bytes are consumed starting directly after the head.
+        s.next_request_start = request.head_end;
+        switch (request.head.transfer_encoding) {
+            .chunked => return .{
+                .readFn = read_chunked,
+                .context = request,
+            },
+            .none => {
+                request.reader_state = .{
+                    .remaining_content_length = request.head.content_length orelse 0,
+                };
+                return .{
+                    .readFn = read_cl,
+                    .context = request,
+                };
+            },
+        }
+    }
+
+    /// Returns whether the connection: keep-alive header should be sent to the client.
+    /// If it would fail, it instead sets the Server state to `receiving_body`
+    /// and returns false.
+    fn discardBody(request: *Request, keep_alive: bool) bool {
+        // Prepare to receive another request on the same connection.
+        // There are two factors to consider:
+        // * Any body the client sent must be discarded.
+        // * The Server's read_buffer may already have some bytes in it from
+        //   whatever came after the head, which may be the next HTTP request
+        //   or the request body.
+        // If the connection won't be kept alive, then none of this matters
+        // because the connection will be severed after the response is sent.
+        const s = request.server;
+        if (keep_alive and request.head.keep_alive) switch (s.state) {
+            .received_head => {
+                s.state = .receiving_body;
+                switch (request.head.transfer_encoding) {
+                    .none => t: {
+                        const len = request.head.content_length orelse break :t;
+                        const head_end = request.head_end;
+                        var total_body_discarded: usize = 0;
+                        while (true) {
+                            const available_bytes = s.read_buffer_len - head_end;
+                            const remaining_len = len - total_body_discarded;
+                            if (available_bytes >= remaining_len) {
+                                s.next_request_start = head_end + remaining_len;
+                                break :t;
+                            }
+                            total_body_discarded += available_bytes;
+                            // Preserve request header memory until receiveHead is called.
+                            const buf = s.read_buffer[head_end..];
+                            const read_n = s.connection.stream.read(buf) catch return false;
+                            // A zero-length read means the client closed the
+                            // connection before sending the full body.
+                            if (read_n == 0) return false;
+                            s.read_buffer_len = head_end + read_n;
+                        }
+                    },
+                    .chunked => {
+                        @panic("TODO");
+                    },
+                }
+                s.state = .ready;
+                return true;
+            },
+            .receiving_body, .ready => return true,
+            else => unreachable,
+        } else {
+            s.state = .closing;
+            return false;
+        }
+    }
+};
 
 pub const Response = struct {
-    transfer_encoding: ResponseTransfer,
-};
+    stream: net.Stream,
+    send_buffer: []u8,
+    /// Index of the first byte in `send_buffer`.
+    /// This is 0 unless a short write happens in `write`.
+    send_buffer_start: usize,
+    /// Index of the last byte + 1 in `send_buffer`.
+    send_buffer_end: usize,
+    /// `null` means transfer-encoding: chunked.
+    /// As a debugging utility, counts down to zero as bytes are written.
+    content_length: ?u64,
+    elide_body: bool,
+    /// Indicates how much of the end of the `send_buffer` corresponds to a
+    /// chunk. This amount of data will be wrapped by an HTTP chunk header.
+    chunk_len: usize,
 
-pub const SendError = Connection.WriteError || error{
-    UnsupportedTransferEncoding,
-    InvalidContentLength,
-};
+    pub const WriteError = net.Stream.WriteError;
 
-/// Send the HTTP response headers to the client.
-pub fn send(res: *Server) SendError!void {
-    switch (res.state) {
-        .waited => res.state = .responded,
-        .first, .start, .responded, .finished => unreachable,
-    }
-
-    var buffered = std.io.bufferedWriter(res.connection.writer());
-    const w = buffered.writer();
-
-    try w.writeAll(@tagName(res.version));
-    try w.writeByte(' ');
-    try w.print("{d}", .{@intFromEnum(res.status)});
-    try w.writeByte(' ');
-    if (res.reason) |reason| {
-        try w.writeAll(reason);
-    } else if (res.status.phrase()) |phrase| {
-        try w.writeAll(phrase);
-    }
-    try w.writeAll("\r\n");
-
-    if (res.status == .@"continue") {
-        res.state = .waited; // we still need to send another request after this
-    } else {
-        res.connection_keep_alive = res.keep_alive and res.request.keep_alive;
-        if (res.connection_keep_alive) {
-            try w.writeAll("connection: keep-alive\r\n");
-        } else {
-            try w.writeAll("connection: close\r\n");
+    /// When using content-length, asserts that the amount of data sent matches
+    /// the value sent in the header, then calls `flush`.
+    /// Otherwise, transfer-encoding: chunked is being used, and it writes the
+    /// end-of-stream message, then flushes the stream to the system.
+    /// When request method is HEAD, does not write anything to the stream.
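+    ///
+    /// Illustrative streaming sketch (buffer size and body are placeholders;
+    /// error handling elided):
+    ///
+    ///     var send_buffer: [4096]u8 = undefined;
+    ///     var response = request.respondStreaming(.{
+    ///         .send_buffer = &send_buffer,
+    ///     });
+    ///     try response.writer().writeAll("streamed body\n");
+    ///     try response.end();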
+ pub fn end(r: *Response) WriteError!void { + if (r.content_length) |len| { + assert(len == 0); // Trips when end() called before all bytes written. + return flush_cl(r); } - - switch (res.transfer_encoding) { - .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), - .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}), - .none => {}, - } - - for (res.extra_headers) |header| { - try w.print("{s}: {s}\r\n", .{ header.name, header.value }); + if (!r.elide_body) { + return flush_chunked(r, &.{}); } + r.* = undefined; } - if (res.request.method == .HEAD) { - res.transfer_encoding = .none; - } - - try w.writeAll("\r\n"); - - try buffered.flush(); -} - -const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - -const TransferReader = std.io.Reader(*Server, TransferReadError, transferRead); - -fn transferReader(res: *Server) TransferReader { - return .{ .context = res }; -} - -fn transferRead(res: *Server, buf: []u8) TransferReadError!usize { - if (res.request.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try res.request.parser.read(&res.connection, buf[index..], false); - if (amt == 0 and res.request.parser.done) break; - index += amt; - } - - return index; -} - -pub const WaitError = Connection.ReadError || - proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || - error{CompressionUnsupported}; - -/// Wait for the client to send a complete request head. -/// -/// For correct behavior, the following rules must be followed: -/// -/// * If this returns any error in `Connection.ReadError`, you MUST -/// immediately close the connection by calling `deinit`. -/// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close -/// the connection by calling `deinit`. -/// * If this returns `error.HttpHeadersOversize`, you MUST -/// respond with a 431 status code and then call `deinit`. -/// * If this returns any error in `Request.ParseError`, you MUST respond -/// with a 400 status code and then call `deinit`. -/// * If this returns any other error, you MUST respond with a 400 status -/// code and then call `deinit`. -/// * If the request has an Expect header containing 100-continue, you MUST either: -/// * Respond with a 100 status code, then call `wait` again. -/// * Respond with a 417 status code. 
-pub fn wait(res: *Server) WaitError!void { - switch (res.state) { - .first, .start => res.state = .waited, - .waited, .responded, .finished => unreachable, - } - - while (true) { - try res.connection.fill(); - - const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek()); - res.connection.drop(@intCast(nchecked)); - - if (res.request.parser.state.isContent()) break; - } - - try res.request.parse(res.request.parser.get()); - - switch (res.request.transfer_encoding) { - .none => { - if (res.request.content_length) |len| { - res.request.parser.next_chunk_length = len; - - if (len == 0) res.request.parser.done = true; - } else { - res.request.parser.done = true; - } - }, - .chunked => { - res.request.parser.next_chunk_length = 0; - res.request.parser.state = .chunk_head_size; - }, - } - - if (!res.request.parser.done) { - switch (res.request.transfer_compression) { - .identity => res.request.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.decompressor(res.transferReader()), - }, - .gzip, .@"x-gzip" => res.request.compression = .{ - .gzip = std.compress.gzip.decompressor(res.transferReader()), - }, - .zstd => { - // https://github.com/ziglang/zig/issues/18937 - return error.CompressionUnsupported; - }, - } - } -} - -pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; - -pub const Reader = std.io.Reader(*Server, ReadError, read); - -pub fn reader(res: *Server) Reader { - return .{ .context = res }; -} - -/// Reads data from the response body. Must be called after `wait`. -pub fn read(res: *Server, buffer: []u8) ReadError!usize { - switch (res.state) { - .waited, .responded, .finished => {}, - .first, .start => unreachable, - } - - const out_index = switch (res.request.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try res.transferRead(buffer), + pub const EndChunkedOptions = struct { + trailers: []const http.Header = &.{}, }; - if (out_index == 0) { - const has_trail = !res.request.parser.state.isContent(); + /// Asserts that the Response is using transfer-encoding: chunked. + /// Writes the end-of-stream message and any optional trailers, then + /// flushes the stream to the system. + /// When request method is HEAD, does not write anything to the stream. + /// Asserts there are at most 25 trailers. + pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + assert(r.content_length == null); + if (r.elide_body) return; + try flush_chunked(r, options.trailers); + r.* = undefined; + } - while (!res.request.parser.state.isContent()) { // read trailing headers - try res.connection.fill(); - - const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek()); - res.connection.drop(@intCast(nchecked)); - } - - if (has_trail) { - // The response headers before the trailers are already - // guaranteed to be valid, so they will always be parsed again - // and cannot return an error. - // This will *only* fail for a malformed trailer. 
- res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers; + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + /// May return 0, which does not indicate end of stream. The caller decides + /// when the end of stream occurs by calling `end`. + pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + if (r.content_length != null) { + return write_cl(r, bytes); + } else { + return write_chunked(r, bytes); } } - return out_index; -} + fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + const len = &r.content_length.?; + if (r.elide_body) { + len.* -= bytes.len; + return bytes.len; + } -/// Reads data from the response body. Must be called after `wait`. -pub fn readAll(res: *Server, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(res, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; -} + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + var iovecs: [2]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + }; + const n = try r.stream.writev(&iovecs); -pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - -pub const Writer = std.io.Writer(*Server, WriteError, write); - -pub fn writer(res: *Server) Writer { - return .{ .context = res }; -} - -/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. -/// Must be called after `send` and before `finish`. -pub fn write(res: *Server, bytes: []const u8) WriteError!usize { - switch (res.state) { - .responded => {}, - .first, .waited, .start, .finished => unreachable, - } - - switch (res.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try res.connection.writer().print("{x}\r\n", .{bytes.len}); - try res.connection.writeAll(bytes); - try res.connection.writeAll("\r\n"); + if (n >= send_buffer_len) { + // It was enough to reset the buffer. + r.send_buffer_start = 0; + r.send_buffer_end = 0; + const bytes_n = n - send_buffer_len; + len.* -= bytes_n; + return bytes_n; } + // It didn't even make it through the existing buffer, let + // alone the new bytes provided. + r.send_buffer_start += n; + return 0; + } + + // All bytes can be stored in the remaining space of the buffer. 
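+        // No syscall happens in this case; the bytes go out on the next
+        // flush, end, or sufficiently large write.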
+ @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + len.* -= bytes.len; + return bytes.len; + } + + fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + assert(r.content_length == null); + + if (r.elide_body) return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - const amt = try res.connection.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } -} + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + const chunk_len = r.chunk_len + bytes.len; + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; -/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. -/// Must be called after `send` and before `finish`. -pub fn writeAll(req: *Server, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } -} + var iovecs: [5]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len - r.chunk_len, + }, + .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }, + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .iov_len = r.chunk_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + .{ + .iov_base = "\r\n", + .iov_len = 2, + }, + }; + // TODO make this writev instead of writevAll, which involves + // complicating the logic of this function. + try r.stream.writevAll(&iovecs); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return bytes.len; + } -pub const FinishError = Connection.WriteError || error{MessageNotCompleted}; - -/// Finish the body of a request. This notifies the server that you have no more data to send. -/// Must be called after `send`. -pub fn finish(res: *Server) FinishError!void { - switch (res.state) { - .responded => res.state = .finished, - .first, .waited, .start, .finished => unreachable, + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + r.chunk_len += bytes.len; + return bytes.len; } - switch (res.transfer_encoding) { - .chunked => try res.connection.writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(r, bytes[index..]); + } } -} -const builtin = @import("builtin"); + /// Sends all buffered data to the client. + /// This is redundant after calling `end`. 
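+        /// When using transfer-encoding: chunked, the bytes buffered since
+        /// the previous flush are sent as a single HTTP chunk, so the caller
+        /// controls the chunk boundaries.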
+    pub fn flush(r: *Response) WriteError!void {
+        if (r.content_length != null) {
+            return flush_cl(r);
+        } else {
+            return flush_chunked(r, null);
+        }
+    }
+
+    fn flush_cl(r: *Response) WriteError!void {
+        assert(r.content_length != null);
+        try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
+        r.send_buffer_start = 0;
+        r.send_buffer_end = 0;
+    }
+
+    fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void {
+        const max_trailers = 25;
+        if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
+        assert(r.content_length == null);
+        const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
+        var header_buf: [18]u8 = undefined;
+        const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable;
+
+        var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined;
+        var iovecs_len: usize = 0;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = r.send_buffer.ptr + r.send_buffer_start,
+            .iov_len = send_buffer_len - r.chunk_len,
+        };
+        iovecs_len += 1;
+
+        // Emit chunk framing only when there is a chunk to send or the
+        // end-of-stream message was requested; a plain flush of an empty
+        // chunk must not emit a premature "0\r\n" terminator.
+        if (end_trailers != null or r.chunk_len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = chunk_header.ptr,
+                .iov_len = chunk_header.len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
+                .iov_len = r.chunk_len,
+            };
+            iovecs_len += 1;
+        }
+
+        if (end_trailers) |trailers| {
+            if (r.chunk_len > 0) {
+                iovecs[iovecs_len] = .{
+                    .iov_base = "\r\n0\r\n",
+                    .iov_len = 5,
+                };
+                iovecs_len += 1;
+            }
+
+            for (trailers) |trailer| {
+                iovecs[iovecs_len] = .{
+                    .iov_base = trailer.name.ptr,
+                    .iov_len = trailer.name.len,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = ": ",
+                    .iov_len = 2,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = trailer.value.ptr,
+                    .iov_len = trailer.value.len,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = "\r\n",
+                    .iov_len = 2,
+                };
+                iovecs_len += 1;
+            }
+
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        } else if (r.chunk_len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        }
+
+        try r.stream.writevAll(iovecs[0..iovecs_len]);
+        r.send_buffer_start = 0;
+        r.send_buffer_end = 0;
+        r.chunk_len = 0;
+    }
+
+    pub fn writer(r: *Response) std.io.AnyWriter {
+        return .{
+            .writeFn = if (r.content_length != null) write_cl else write_chunked,
+            .context = r,
+        };
+    }
+};
 
 const std = @import("../std.zig");
-const testing = std.testing;
 const http = std.http;
 const mem = std.mem;
 const net = std.net;
@@ -615,4 +807,3 @@ const Uri = std.Uri;
 const assert = std.debug.assert;
 
 const Server = @This();
-const proto = @import("protocol.zig");
diff --git a/lib/std/http/Server/Connection.zig b/lib/std/http/Server/Connection.zig
deleted file mode 100644
index 74997fc140..0000000000
--- a/lib/std/http/Server/Connection.zig
+++ /dev/null
@@ -1,119 +0,0 @@
-stream: std.net.Stream,
-
-read_buf: [buffer_size]u8,
-read_start: u16,
-read_end: u16,
-
-pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
-
-pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
-    return conn.stream.readAtLeast(buffer, len) catch |err| {
-        switch (err) {
-            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedReadFailure,
-        }
-    };
-}
-
-pub fn fill(conn: *Connection) ReadError!void {
-    if (conn.read_end != conn.read_start) return;
-
-    const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
-    if (nread == 0) 
return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); -} - -pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; -} - -pub fn drop(conn: *Connection, num: u16) void { - conn.read_start += num; -} - -pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); - - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } - - try conn.fill(); - } - - return out_index; -} - -pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); -} - -pub const ReadError = error{ - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, -}; - -pub const Reader = std.io.Reader(*Connection, ReadError, read); - -pub fn reader(conn: *Connection) Reader { - return .{ .context = conn }; -} - -pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; -} - -pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - return conn.stream.write(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; -} - -pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, -}; - -pub const Writer = std.io.Writer(*Connection, WriteError, write); - -pub fn writer(conn: *Connection) Writer { - return .{ .context = conn }; -} - -pub fn close(conn: *Connection) void { - conn.stream.close(); -} - -const Connection = @This(); -const std = @import("../../std.zig"); -const assert = std.debug.assert; diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 4c69a79105..d080d3c389 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -73,339 +73,28 @@ pub const HeadersParser = struct { return hp.header_bytes_buffer[0..hp.header_bytes_len]; } - /// Returns the number of bytes consumed by headers. This is always less - /// than or equal to `bytes.len`. - /// You should check `r.state.isContent()` after this to check if the - /// headers are done. - /// - /// If the amount returned is less than `bytes.len`, you may assume that - /// the parser is in a content state and the - /// first byte of content is located at `bytes[result]`. 
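+    /// Returns the number of bytes consumed by headers. You should check
+    /// `r.state.isContent()` after this to check if the headers are done.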
pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8); - const len: u32 = @intCast(bytes.len); - var index: u32 = 0; - - while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => return index, - .start => switch (len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - return index + 1; - }, - 2 => { - const b16 = int16(bytes[index..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - return index + 2; - }, - 3 => { - const b24 = int24(bytes[index..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - return index + 3; - }, - 4...vector_len - 1 => { - const b32 = int32(bytes[index..][0..4]); - const b24 = intShift(u24, b32); - const b16 = intShift(u16, b32); - const b8 = intShift(u8, b32); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - switch (b32) { - int32("\r\n\r\n") => r.state = .finished, - else => {}, - } - - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..vector_len]; - const matches = if (use_vectors) matches: { - const Vector = @Vector(vector_len, u8); - // const BoolVector = @Vector(vector_len, bool); - const BitVector = @Vector(vector_len, u1); - const SizeVector = @Vector(vector_len, u8); - - const v: Vector = chunk.*; - const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r'))); - const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n'))); - const matches_or: SizeVector = matches_r | matches_n; - - break :matches @reduce(.Add, matches_or); - } else matches: { - var matches: u8 = 0; - for (chunk) |byte| switch (byte) { - '\r', '\n' => matches += 1, - else => {}, - }; - break :matches matches; - }; - switch (matches) { - 0 => {}, - 1 => switch (chunk[vector_len - 1]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - }, - 2 => { - const b16 = int16(chunk[vector_len - 2 ..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - }, - 3 => { - const b24 = int24(chunk[vector_len - 3 ..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - }, - 4...vector_len => { - inline for (0..vector_len - 3) |i_usize| { - const i = @as(u32, @truncate(i_usize)); - - 
const b32 = int32(chunk[i..][0..4]);
-                                    const b16 = intShift(u16, b32);
-
-                                    if (b32 == int32("\r\n\r\n")) {
-                                        r.state = .finished;
-                                        return index + i + 4;
-                                    } else if (b16 == int16("\n\n")) {
-                                        r.state = .finished;
-                                        return index + i + 2;
-                                    }
-                                }
-
-                                const b24 = int24(chunk[vector_len - 3 ..][0..3]);
-                                const b16 = intShift(u16, b24);
-                                const b8 = intShift(u8, b24);
-
-                                switch (b8) {
-                                    '\r' => r.state = .seen_r,
-                                    '\n' => r.state = .seen_n,
-                                    else => {},
-                                }
-
-                                switch (b16) {
-                                    int16("\r\n") => r.state = .seen_rn,
-                                    int16("\n\n") => r.state = .finished,
-                                    else => {},
-                                }
-
-                                switch (b24) {
-                                    int24("\r\n\r") => r.state = .seen_rnr,
-                                    else => {},
-                                }
-                            },
-                            else => unreachable,
-                        }
-
-                        index += vector_len;
-                        continue;
-                    },
-                },
-                .seen_n => switch (len - index) {
-                    0 => return index,
-                    else => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .finished,
-                            else => r.state = .start,
-                        }
-
-                        index += 1;
-                        continue;
-                    },
-                },
-                .seen_r => switch (len - index) {
-                    0 => return index,
-                    1 => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .seen_rn,
-                            '\r' => r.state = .seen_r,
-                            else => r.state = .start,
-                        }
-
-                        return index + 1;
-                    },
-                    2 => {
-                        const b16 = int16(bytes[index..][0..2]);
-                        const b8 = intShift(u8, b16);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_r,
-                            '\n' => r.state = .seen_rn,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\r") => r.state = .seen_rnr,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        return index + 2;
-                    },
-                    else => {
-                        const b24 = int24(bytes[index..][0..3]);
-                        const b16 = intShift(u16, b24);
-                        const b8 = intShift(u8, b24);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_r,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        switch (b24) {
-                            int24("\n\r\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        index += 3;
-                        continue;
-                    },
-                },
-                .seen_rn => switch (len - index) {
-                    0 => return index,
-                    1 => {
-                        switch (bytes[index]) {
-                            '\r' => r.state = .seen_rnr,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        return index + 1;
-                    },
-                    else => {
-                        const b16 = int16(bytes[index..][0..2]);
-                        const b8 = intShift(u8, b16);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_rnr,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .finished,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        index += 2;
-                        continue;
-                    },
-                },
-                .seen_rnr => switch (len - index) {
-                    0 => return index,
-                    else => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .finished,
-                            else => r.state = .start,
-                        }
-
-                        index += 1;
-                        continue;
-                    },
-                },
-                .chunk_head_size => unreachable,
-                .chunk_head_ext => unreachable,
-                .chunk_head_r => unreachable,
-                .chunk_data => unreachable,
-                .chunk_data_suffix => unreachable,
-                .chunk_data_suffix_r => unreachable,
-            }
-
-            return index;
-        }
+        var hp: std.http.HeadParser = .{
+            .state = switch (r.state) {
+                .start => .start,
+                .seen_n => .seen_n,
+                .seen_r => .seen_r,
+                .seen_rn => .seen_rn,
+                .seen_rnr => .seen_rnr,
+                .finished => .finished,
+                else => unreachable,
+            },
+        };
+        const result = hp.feed(bytes);
+        r.state = switch (hp.state) {
+            .start => .start,
+            .seen_n => .seen_n,
+            .seen_r => .seen_r,
+            .seen_rn => .seen_rn,
+            .seen_rnr => .seen_rnr,
+            .finished => .finished,
+        };
+        return @intCast(result);
     }
 
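Both enums share these six head-scanning states, so the mapping above is total for every state findHeadersEnd can legitimately be called in; the `else => unreachable` arm covers only HeadersParser's chunk_* states, which belong to findChunkedLen. A quick check that the wrapper preserves the old behavior (a hypothetical test, in the spirit of the findHeadersEnd test deleted from this file below):

    test "findHeadersEnd delegates to HeadParser" {
        var r = HeadersParser.init(&.{});
        const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
        try std.testing.expectEqual(@as(u32, 35), r.findHeadersEnd(data));
        try std.testing.expect(r.state == .finished);
    }
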
     /// Returns the number of bytes consumed by the chunk size. This is always
@@ -775,17 +464,6 @@ const MockBufferedConnection = struct {
     }
 };
 
-test "HeadersParser.findHeadersEnd" {
-    var r: HeadersParser = undefined;
-    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
-
-    for (0..36) |i| {
-        r = HeadersParser.init(&.{});
-        try std.testing.expectEqual(@as(u32, @intCast(i)), r.findHeadersEnd(data[0..i]));
-        try std.testing.expectEqual(@as(u32, @intCast(35 - i)), r.findHeadersEnd(data[i..]));
-    }
-}
-
 test "HeadersParser.findChunkedLen" {
     var r: HeadersParser = undefined;
     const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig
index 3441346baf..3dbccbcff6 100644
--- a/lib/std/http/test.zig
+++ b/lib/std/http/test.zig
@@ -69,31 +69,35 @@ test "trailers" {
 fn serverThread(http_server: *std.net.Server) anyerror!void {
     var header_buffer: [1024]u8 = undefined;
     var remaining: usize = 1;
-    accept: while (remaining != 0) : (remaining -= 1) {
+    while (remaining != 0) : (remaining -= 1) {
         const conn = try http_server.accept();
         defer conn.stream.close();
 
-        var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
+        var server = std.http.Server.init(conn, &header_buffer);
 
-        res.wait() catch |err| switch (err) {
-            error.HttpHeadersInvalid => continue :accept,
-            error.EndOfStream => continue,
-            else => return err,
-        };
-        try serve(&res);
-
-        try testing.expectEqual(.reset, res.reset());
+        try testing.expectEqual(.ready, server.state);
+        var request = try server.receiveHead();
+        try serve(&request);
+        try testing.expectEqual(.ready, server.state);
     }
 }
 
-fn serve(res: *std.http.Server) !void {
-    try testing.expectEqualStrings(res.request.target, "/trailer");
-    res.transfer_encoding = .chunked;
+fn serve(request: *std.http.Server.Request) !void {
+    try testing.expectEqualStrings(request.head.target, "/trailer");
 
-    try res.send();
-    try res.writeAll("Hello, ");
-    try res.writeAll("World!\n");
-    try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
+    var send_buffer: [1024]u8 = undefined;
+    var response = request.respondStreaming(.{
+        .send_buffer = &send_buffer,
+    });
+    try response.writeAll("Hello, ");
+    try response.flush();
+    try response.writeAll("World!\n");
+    try response.flush();
+    try response.endChunked(.{
+        .trailers = &.{
+            .{ .name = "X-Checksum", .value = "aaaa" },
+        },
+    });
 }
 
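With serve() rewritten this way, the trailer framing the old test pushed by hand through `res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n")` is now produced by the flush/endChunked machinery. The body bytes the client should observe (response head omitted, CR LF spelled out):

    7\r\nHello, \r\n              <- first flush: one chunk
    7\r\nWorld!\n\r\n             <- second flush
    0\r\nX-Checksum: aaaa\r\n\r\n <- endChunked: terminator, trailer, blank line

The last line is byte-for-byte what the old test wrote manually, so the expected client-side behavior is unchanged.
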
= "content-type", .value = "text/plain" }, - }; - res.keep_alive = false; - try res.send(); + var server = std.http.Server.init(conn, &header_buffer); + var request = try server.receiveHead(); + + try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); + const n = try request.reader().readAll(&buf); try expect(std.mem.eql(u8, buf[0..n], "ABCD")); - _ = try res.writer().writeAll(server_body); - try res.finish(); + + try request.respond("message from server!\n", .{ + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + .keep_alive = false, + }); } - }).apply, .{&server}); + }).apply, .{&socket_server}); const request_bytes = "POST / HTTP/1.1\r\n" ++