diff --git a/lib/std/http.zig b/lib/std/http.zig
index bdeab598a6..613abc66b2 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -3,6 +3,7 @@ const std = @import("std.zig");
 pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
+pub const HeadParser = @import("http/HeadParser.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -311,5 +312,6 @@ test {
     _ = Method;
     _ = Server;
     _ = Status;
+    _ = HeadParser;
     _ = @import("http/test.zig");
 }
diff --git a/lib/std/http/HeadParser.zig b/lib/std/http/HeadParser.zig
new file mode 100644
index 0000000000..07c357731a
--- /dev/null
+++ b/lib/std/http/HeadParser.zig
@@ -0,0 +1,370 @@
+state: State = .start,
+
+pub const State = enum {
+    start,
+    seen_n,
+    seen_r,
+    seen_rn,
+    seen_rnr,
+    finished,
+};
+
+/// Returns the number of bytes consumed by headers. This is always less
+/// than or equal to `bytes.len`.
+///
+/// If the amount returned is less than `bytes.len`, the parser is in a
+/// content state and the first byte of content is located at
+/// `bytes[result]`.
+pub fn feed(p: *HeadParser, bytes: []const u8) usize {
+    const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8);
+    const len: u32 = @intCast(bytes.len);
+    var index: u32 = 0;
+
+    while (true) {
+        switch (p.state) {
+            .finished => return index,
+            .start => switch (len - index) {
+                0 => return index,
+                1 => {
+                    switch (bytes[index]) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_n,
+                        else => {},
+                    }
+
+                    return index + 1;
+                },
+                2 => {
+                    const b16 = int16(bytes[index..][0..2]);
+                    const b8 = intShift(u8, b16);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_n,
+                        else => {},
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .seen_rn,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    return index + 2;
+                },
+                3 => {
+                    const b24 = int24(bytes[index..][0..3]);
+                    const b16 = intShift(u16, b24);
+                    const b8 = intShift(u8, b24);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_n,
+                        else => {},
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .seen_rn,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    switch (b24) {
+                        int24("\r\n\r") => p.state = .seen_rnr,
+                        else => {},
+                    }
+
+                    return index + 3;
+                },
+                4...vector_len - 1 => {
+                    const b32 = int32(bytes[index..][0..4]);
+                    const b24 = intShift(u24, b32);
+                    const b16 = intShift(u16, b32);
+                    const b8 = intShift(u8, b32);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_n,
+                        else => {},
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .seen_rn,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    switch (b24) {
+                        int24("\r\n\r") => p.state = .seen_rnr,
+                        else => {},
+                    }
+
+                    switch (b32) {
+                        int32("\r\n\r\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    index += 4;
+                    continue;
+                },
+                else => {
+                    const chunk = bytes[index..][0..vector_len];
+                    const matches = if (use_vectors) matches: {
+                        // const BoolVector = @Vector(vector_len, bool);
+                        const Vector = @Vector(vector_len, u8);
+                        const BitVector = @Vector(vector_len, u1);
+                        const SizeVector = @Vector(vector_len, u8);
+
+                        const v: Vector = chunk.*;
+                        const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r')));
+                        const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n')));
+                        const matches_or: SizeVector = matches_r | matches_n;
+
+                        break :matches @reduce(.Add, matches_or);
+                    } else matches: {
+                        var matches: u8 = 0;
+                        for (chunk) |byte| switch (byte) {
+                            '\r', '\n' => matches += 1,
+                            else => {},
+                        };
+                        break :matches matches;
+                    };
+                    switch (matches) {
+                        0 => {},
+                        1 => switch (chunk[vector_len - 1]) {
+                            '\r' => p.state = .seen_r,
+                            '\n' => p.state = .seen_n,
+                            else => {},
+                        },
+                        2 => {
+                            const b16 = int16(chunk[vector_len - 2 ..][0..2]);
+                            const b8 = intShift(u8, b16);
+
+                            switch (b8) {
+                                '\r' => p.state = .seen_r,
+                                '\n' => p.state = .seen_n,
+                                else => {},
+                            }
+
+                            switch (b16) {
+                                int16("\r\n") => p.state = .seen_rn,
+                                int16("\n\n") => p.state = .finished,
+                                else => {},
+                            }
+                        },
+                        3 => {
+                            const b24 = int24(chunk[vector_len - 3 ..][0..3]);
+                            const b16 = intShift(u16, b24);
+                            const b8 = intShift(u8, b24);
+
+                            switch (b8) {
+                                '\r' => p.state = .seen_r,
+                                '\n' => p.state = .seen_n,
+                                else => {},
+                            }
+
+                            switch (b16) {
+                                int16("\r\n") => p.state = .seen_rn,
+                                int16("\n\n") => p.state = .finished,
+                                else => {},
+                            }
+
+                            switch (b24) {
+                                int24("\r\n\r") => p.state = .seen_rnr,
+                                else => {},
+                            }
+                        },
+                        4...vector_len => {
+                            inline for (0..vector_len - 3) |i_usize| {
+                                const i = @as(u32, @truncate(i_usize));
+
+                                const b32 = int32(chunk[i..][0..4]);
+                                const b16 = intShift(u16, b32);
+
+                                if (b32 == int32("\r\n\r\n")) {
+                                    p.state = .finished;
+                                    return index + i + 4;
+                                } else if (b16 == int16("\n\n")) {
+                                    p.state = .finished;
+                                    return index + i + 2;
+                                }
+                            }
+
+                            const b24 = int24(chunk[vector_len - 3 ..][0..3]);
+                            const b16 = intShift(u16, b24);
+                            const b8 = intShift(u8, b24);
+
+                            switch (b8) {
+                                '\r' => p.state = .seen_r,
+                                '\n' => p.state = .seen_n,
+                                else => {},
+                            }
+
+                            switch (b16) {
+                                int16("\r\n") => p.state = .seen_rn,
+                                int16("\n\n") => p.state = .finished,
+                                else => {},
+                            }
+
+                            switch (b24) {
+                                int24("\r\n\r") => p.state = .seen_rnr,
+                                else => {},
+                            }
+                        },
+                        else => unreachable,
+                    }
+
+                    index += vector_len;
+                    continue;
+                },
+            },
+            .seen_n => switch (len - index) {
+                0 => return index,
+                else => {
+                    switch (bytes[index]) {
+                        '\n' => p.state = .finished,
+                        else => p.state = .start,
+                    }
+
+                    index += 1;
+                    continue;
+                },
+            },
+            .seen_r => switch (len - index) {
+                0 => return index,
+                1 => {
+                    switch (bytes[index]) {
+                        '\n' => p.state = .seen_rn,
+                        '\r' => p.state = .seen_r,
+                        else => p.state = .start,
+                    }
+
+                    return index + 1;
+                },
+                2 => {
+                    const b16 = int16(bytes[index..][0..2]);
+                    const b8 = intShift(u8, b16);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_rn,
+                        else => p.state = .start,
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .seen_rn,
+                        int16("\n\r") => p.state = .seen_rnr,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    return index + 2;
+                },
+                else => {
+                    const b24 = int24(bytes[index..][0..3]);
+                    const b16 = intShift(u16, b24);
+                    const b8 = intShift(u8, b24);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_r,
+                        '\n' => p.state = .seen_n,
+                        else => p.state = .start,
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .seen_rn,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    switch (b24) {
+                        int24("\n\r\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    index += 3;
+                    continue;
+                },
+            },
+            .seen_rn => switch (len - index) {
+                0 => return index,
+                1 => {
+                    switch (bytes[index]) {
+                        '\r' => p.state = .seen_rnr,
+                        '\n' => p.state = .seen_n,
+                        else => p.state = .start,
+                    }
+
+                    return index + 1;
+                },
+                else => {
+                    const b16 = int16(bytes[index..][0..2]);
+                    const b8 = intShift(u8, b16);
+
+                    switch (b8) {
+                        '\r' => p.state = .seen_rnr,
+                        '\n' => p.state = .seen_n,
+                        else => p.state = .start,
+                    }
+
+                    switch (b16) {
+                        int16("\r\n") => p.state = .finished,
+                        int16("\n\n") => p.state = .finished,
+                        else => {},
+                    }
+
+                    index += 2;
+                    continue;
+                },
+            },
+            .seen_rnr => switch (len - index) {
+                0 => return index,
+                else => {
+                    switch (bytes[index]) {
+                        '\n' => p.state = .finished,
+                        else => p.state = .start,
+                    }
+
+                    index += 1;
+                    continue;
+                },
+            },
+        }
+
+        return index;
+    }
+}
+
+inline fn int16(array: *const [2]u8) u16 {
+    return @bitCast(array.*);
+}
+
+inline fn int24(array: *const [3]u8) u24 {
+    return @bitCast(array.*);
+}
+
+inline fn int32(array: *const [4]u8) u32 {
+    return @bitCast(array.*);
+}
+
+inline fn intShift(comptime T: type, x: anytype) T {
+    switch (@import("builtin").cpu.arch.endian()) {
+        .little => return @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))),
+        .big => return @truncate(x),
+    }
+}
+
+const HeadParser = @This();
+const std = @import("std");
+const use_vectors = builtin.zig_backend != .stage2_x86_64;
+const builtin = @import("builtin");
+
+test feed {
+    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
+
+    for (0..36) |i| {
+        var p: HeadParser = .{};
+        try std.testing.expectEqual(i, p.feed(data[0..i]));
+        try std.testing.expectEqual(35 - i, p.feed(data[i..]));
+    }
+}
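For callers outside this file, the parser is driven entirely through `feed` and `state`, so a head split across reads can be fed piecewise. A minimal usage sketch (not part of the patch), in test form:

    const std = @import("std");

    test "HeadParser accepts a head split across reads" {
        const head = "GET / HTTP/1.1\r\nHost: example\r\n\r\n";
        var p: std.http.HeadParser = .{};
        // feed returns how many bytes belong to the head; summing across
        // arbitrary split points yields the same total as a single feed.
        var consumed: usize = p.feed(head[0..7]);
        consumed += p.feed(head[7..]);
        try std.testing.expect(p.state == .finished);
        try std.testing.expectEqual(head.len, consumed);
    }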
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
index ce770f14c4..73289d713f 100644
--- a/lib/std/http/Server.zig
+++ b/lib/std/http/Server.zig
@@ -1,613 +1,805 @@
-connection: Connection,
-/// This value is determined by Server when sending headers to the client, and
-/// then used to determine the return value of `reset`.
-connection_keep_alive: bool,
+//! Blocking HTTP server implementation.
 
-/// The HTTP request that this response is responding to.
-///
-/// This field is only valid after calling `wait`.
-request: Request,
+connection: net.Server.Connection,
+/// Keeps track of whether the Server is ready to accept a new request on the
+/// same connection, and makes invalid API usage cause assertion failures
+/// rather than HTTP protocol violations.
+state: State,
+/// User-provided buffer that must outlive this Server.
+/// Used to store the client's entire HTTP header.
+read_buffer: []u8,
+/// Amount of available data inside read_buffer.
+read_buffer_len: usize,
+/// Index into `read_buffer` of the first byte of the next HTTP request.
+next_request_start: usize,
 
-state: State = .first,
+pub const State = enum {
+    /// The connection is available to be used for the first time, or reused.
+    ready,
+    /// An error occurred in `receiveHead`.
+    receiving_head,
+    /// A Request object has been obtained and from there a Response can be
+    /// opened.
+    received_head,
+    /// The client is uploading something to this Server.
+    receiving_body,
+    /// The connection is eligible for another HTTP request, however the client
+    /// and server did not negotiate connection: keep-alive.
+    closing,
+};
 
 /// Initialize an HTTP server that can respond to multiple requests on the same
 /// connection.
-/// The returned `Server` is ready for `reset` or `wait` to be called.
-pub fn init(connection: std.net.Server.Connection, options: Server.Request.InitOptions) Server {
+/// The returned `Server` is ready for `receiveHead` to be called.
+pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
     return .{
-        .connection = .{
-            .stream = connection.stream,
-            .read_buf = undefined,
-            .read_start = 0,
-            .read_end = 0,
-        },
-        .connection_keep_alive = false,
-        .request = Server.Request.init(options),
+        .connection = connection,
+        .state = .ready,
+        .read_buffer = read_buffer,
+        .read_buffer_len = 0,
+        .next_request_start = 0,
     };
 }
 
-pub const State = enum {
-    first,
-    start,
-    waited,
-    responded,
-    finished,
+pub const ReceiveHeadError = error{
+    /// Client sent too many bytes of HTTP headers.
+    /// The HTTP specification suggests to respond with a 431 status code
+    /// before closing the connection.
+    HttpHeadersOversize,
+    /// Client sent headers that did not conform to the HTTP protocol.
+    HttpHeadersInvalid,
+    /// A low level I/O error occurred trying to read the headers.
+    HttpHeadersUnreadable,
 };
 
-pub const ResetState = enum { reset, closing };
+/// The header bytes reference the read buffer that Server was initialized with
+/// and remain alive until the next call to receiveHead.
+pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
+    assert(s.state == .ready);
+    s.state = .received_head;
+    errdefer s.state = .receiving_head;
 
-pub const Connection = @import("Server/Connection.zig");
-
-/// The mode of transport for responses.
-pub const ResponseTransfer = union(enum) {
-    content_length: u64,
-    chunked: void,
-    none: void,
-};
-
-/// The decompressor for request messages.
-pub const Compression = union(enum) {
-    pub const DeflateDecompressor = std.compress.zlib.Decompressor(Server.TransferReader);
-    pub const GzipDecompressor = std.compress.gzip.Decompressor(Server.TransferReader);
-    // https://github.com/ziglang/zig/issues/18937
-    //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Server.TransferReader, .{});
-
-    deflate: DeflateDecompressor,
-    gzip: GzipDecompressor,
-    // https://github.com/ziglang/zig/issues/18937
-    //zstd: ZstdDecompressor,
-    none: void,
-};
-
-/// A HTTP request originating from a client.
-pub const Request = struct {
-    method: http.Method,
-    target: []const u8,
-    version: http.Version,
-    expect: ?[]const u8,
-    content_type: ?[]const u8,
-    content_length: ?u64,
-    transfer_encoding: http.TransferEncoding,
-    transfer_compression: http.ContentEncoding,
-    keep_alive: bool,
-    parser: proto.HeadersParser,
-    compression: Compression,
-
-    pub const InitOptions = struct {
-        /// Externally-owned memory used to store the client's entire HTTP header.
-        /// `error.HttpHeadersOversize` is returned from read() when a
-        /// client sends too many bytes of HTTP headers.
-        client_header_buffer: []u8,
-    };
-
-    pub fn init(options: InitOptions) Request {
-        return .{
-            .method = undefined,
-            .target = undefined,
-            .version = undefined,
-            .expect = null,
-            .content_type = null,
-            .content_length = null,
-            .transfer_encoding = .none,
-            .transfer_compression = .identity,
-            .keep_alive = false,
-            .parser = proto.HeadersParser.init(options.client_header_buffer),
-            .compression = .none,
-        };
+    // In case of a reused connection, move the next request's bytes to the
+    // beginning of the buffer.
+    if (s.next_request_start > 0) {
+        if (s.read_buffer_len > s.next_request_start) {
+            const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
+            const dest = s.read_buffer[0..leftover.len];
+            if (leftover.len <= s.next_request_start) {
+                @memcpy(dest, leftover);
+            } else {
+                mem.copyBackwards(u8, dest, leftover);
+            }
+            s.read_buffer_len = leftover.len;
+        } else {
+            // Everything buffered belonged to the previous request.
+            s.read_buffer_len = 0;
+        }
+        s.next_request_start = 0;
     }
 
-    pub const ParseError = error{
-        UnknownHttpMethod,
-        HttpHeadersInvalid,
-        HttpHeaderContinuationsUnsupported,
-        HttpTransferEncodingUnsupported,
-        HttpConnectionHeaderUnsupported,
-        InvalidContentLength,
-        CompressionUnsupported,
-    };
+    var hp: http.HeadParser = .{};
+    // A pipelined request may already be sitting in the buffer; feed those
+    // bytes first. `head_end` accumulates how many bytes of `read_buffer`
+    // the head spans in total.
+    var head_end: usize = hp.feed(s.read_buffer[0..s.read_buffer_len]);
+    while (hp.state != .finished) {
+        const buf = s.read_buffer[s.read_buffer_len..];
+        if (buf.len == 0)
+            return error.HttpHeadersOversize;
+        const read_n = s.connection.stream.read(buf) catch
+            return error.HttpHeadersUnreadable;
+        if (read_n == 0)
+            return error.HttpHeadersInvalid; // connection closed mid-head
+        s.read_buffer_len += read_n;
+        head_end += hp.feed(buf[0..read_n]);
+    }
+    return .{
+        .server = s,
+        .head_end = head_end,
+        .head = Request.Head.parse(s.read_buffer[0..head_end]) catch
+            return error.HttpHeadersInvalid,
+        .reader_state = undefined,
+    };
+}
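A minimal accept loop wiring `init`, `receiveHead`, and the `respond` method defined below (a sketch, not part of the patch; error handling is reduced to `catch break` and the address and buffer size are arbitrary):

    const std = @import("std");

    pub fn main() !void {
        const addr = try std.net.Address.parseIp("127.0.0.1", 8080);
        var listener = try addr.listen(.{ .reuse_address = true });
        defer listener.deinit();

        var read_buffer: [8192]u8 = undefined;
        while (true) {
            const conn = try listener.accept();
            defer conn.stream.close();
            var server = std.http.Server.init(conn, &read_buffer);
            // Keep serving on this connection while keep-alive holds.
            while (server.state == .ready) {
                var request = server.receiveHead() catch break;
                request.respond("hello\n", .{}) catch break;
            }
        }
    }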
+
+pub const Request = struct {
+    server: *Server,
+    /// Index into Server's read_buffer.
+    head_end: usize,
+    head: Head,
+    reader_state: union {
+        remaining_content_length: u64,
+    },
+
+    pub const Compression = union(enum) {
+        pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader);
+        pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader);
+        pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader);
+
+        deflate: DeflateDecompressor,
+        gzip: GzipDecompressor,
+        zstd: ZstdDecompressor,
+        none: void,
+    };
 
-    pub fn parse(req: *Request, bytes: []const u8) ParseError!void {
-        var it = mem.splitSequence(u8, bytes, "\r\n");
-
-        const first_line = it.next().?;
-        if (first_line.len < 10)
-            return error.HttpHeadersInvalid;
-
-        const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse
-            return error.HttpHeadersInvalid;
-        if (method_end > 24) return error.HttpHeadersInvalid;
-
-        const method_str = first_line[0..method_end];
-        const method: http.Method = @enumFromInt(http.Method.parse(method_str));
-
-        const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse
-            return error.HttpHeadersInvalid;
-        if (version_start == method_end) return error.HttpHeadersInvalid;
-
-        const version_str = first_line[version_start + 1 ..];
-        if (version_str.len != 8) return error.HttpHeadersInvalid;
-        const version: http.Version = switch (int64(version_str[0..8])) {
-            int64("HTTP/1.0") => .@"HTTP/1.0",
-            int64("HTTP/1.1") => .@"HTTP/1.1",
-            else => return error.HttpHeadersInvalid,
-        };
-
-        const target = first_line[method_end + 1 .. version_start];
-
-        req.method = method;
-        req.target = target;
-        req.version = version;
-
-        while (it.next()) |line| {
-            if (line.len == 0) return;
-            switch (line[0]) {
-                ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
-                else => {},
-            }
-
-            var line_it = mem.splitSequence(u8, line, ": ");
-            const header_name = line_it.next().?;
-            const header_value = line_it.rest();
-            if (header_value.len == 0) return error.HttpHeadersInvalid;
-
-            if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
-                req.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
-            } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) {
-                req.expect = header_value;
-            } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
-                req.content_type = header_value;
-            } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
-                if (req.content_length != null) return error.HttpHeadersInvalid;
-                req.content_length = std.fmt.parseInt(u64, header_value, 10) catch
-                    return error.InvalidContentLength;
-            } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
-                if (req.transfer_compression != .identity) return error.HttpHeadersInvalid;
-
-                const trimmed = mem.trim(u8, header_value, " ");
-
-                if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
-                    req.transfer_compression = ce;
-                } else {
-                    return error.HttpTransferEncodingUnsupported;
-                }
-            } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
-                // Transfer-Encoding: second, first
-                // Transfer-Encoding: deflate, chunked
-                var iter = mem.splitBackwardsScalar(u8, header_value, ',');
-
-                const first = iter.first();
-                const trimmed_first = mem.trim(u8, first, " ");
-
-                var next: ?[]const u8 = first;
-                if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
-                    if (req.transfer_encoding != .none)
-                        return error.HttpHeadersInvalid; // we already have a transfer encoding
-                    req.transfer_encoding = transfer;
-
-                    next = iter.next();
-                }
-
-                if (next) |second| {
-                    const trimmed_second = mem.trim(u8, second, " ");
-
-                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| {
-                        if (req.transfer_compression != .identity)
-                            return error.HttpHeadersInvalid; // double compression is not supported
-                        req.transfer_compression = transfer;
-                    } else {
-                        return error.HttpTransferEncodingUnsupported;
-                    }
-                }
-
-                if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
-            }
-        }
-
-        return error.HttpHeadersInvalid; // missing empty line
-    }
-
-    inline fn int64(array: *const [8]u8) u64 {
-        return @bitCast(array.*);
-    }
-};
+    pub const Head = struct {
+        method: http.Method,
+        target: []const u8,
+        version: http.Version,
+        expect: ?[]const u8,
+        content_type: ?[]const u8,
+        content_length: ?u64,
+        transfer_encoding: http.TransferEncoding,
+        transfer_compression: http.ContentEncoding,
+        keep_alive: bool,
+        compression: Compression,
+
+        pub const ParseError = error{
+            UnknownHttpMethod,
+            HttpHeadersInvalid,
+            HttpHeaderContinuationsUnsupported,
+            HttpTransferEncodingUnsupported,
+            HttpConnectionHeaderUnsupported,
+            InvalidContentLength,
+            CompressionUnsupported,
+            MissingFinalNewline,
+        };
+
+        pub fn parse(bytes: []const u8) ParseError!Head {
+            var it = mem.splitSequence(u8, bytes, "\r\n");
+
+            const first_line = it.next().?;
+            if (first_line.len < 10)
+                return error.HttpHeadersInvalid;
+
+            const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse
+                return error.HttpHeadersInvalid;
+            if (method_end > 24) return error.HttpHeadersInvalid;
+
+            const method_str = first_line[0..method_end];
+            const method: http.Method = @enumFromInt(http.Method.parse(method_str));
+
+            const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse
+                return error.HttpHeadersInvalid;
+            if (version_start == method_end) return error.HttpHeadersInvalid;
+
+            const version_str = first_line[version_start + 1 ..];
+            if (version_str.len != 8) return error.HttpHeadersInvalid;
+            const version: http.Version = switch (int64(version_str[0..8])) {
+                int64("HTTP/1.0") => .@"HTTP/1.0",
+                int64("HTTP/1.1") => .@"HTTP/1.1",
+                else => return error.HttpHeadersInvalid,
+            };
+
+            const target = first_line[method_end + 1 .. version_start];
+
+            var head: Head = .{
+                .method = method,
+                .target = target,
+                .version = version,
+                .expect = null,
+                .content_type = null,
+                .content_length = null,
+                .transfer_encoding = .none,
+                .transfer_compression = .identity,
+                .keep_alive = false,
+                .compression = .none,
+            };
+
+            while (it.next()) |line| {
+                if (line.len == 0) return head;
+                switch (line[0]) {
+                    ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+                    else => {},
+                }
+
+                var line_it = mem.splitSequence(u8, line, ": ");
+                const header_name = line_it.next().?;
+                const header_value = line_it.rest();
+                if (header_value.len == 0) return error.HttpHeadersInvalid;
+
+                if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+                    head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
+                } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) {
+                    head.expect = header_value;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
+                    head.content_type = header_value;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+                    if (head.content_length != null) return error.HttpHeadersInvalid;
+                    head.content_length = std.fmt.parseInt(u64, header_value, 10) catch
+                        return error.InvalidContentLength;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+                    if (head.transfer_compression != .identity) return error.HttpHeadersInvalid;
+
+                    const trimmed = mem.trim(u8, header_value, " ");
+
+                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+                        head.transfer_compression = ce;
+                    } else {
+                        return error.HttpTransferEncodingUnsupported;
+                    }
+                } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+                    // Transfer-Encoding: second, first
+                    // Transfer-Encoding: deflate, chunked
+                    var iter = mem.splitBackwardsScalar(u8, header_value, ',');
+
+                    const first = iter.first();
+                    const trimmed_first = mem.trim(u8, first, " ");
+
+                    var next: ?[]const u8 = first;
+                    if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
+                        if (head.transfer_encoding != .none)
+                            return error.HttpHeadersInvalid; // we already have a transfer encoding
+                        head.transfer_encoding = transfer;
+
+                        next = iter.next();
+                    }
+
+                    if (next) |second| {
+                        const trimmed_second = mem.trim(u8, second, " ");
+
+                        if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| {
+                            if (head.transfer_compression != .identity)
+                                return error.HttpHeadersInvalid; // double compression is not supported
+                            head.transfer_compression = transfer;
+                        } else {
+                            return error.HttpTransferEncodingUnsupported;
+                        }
+                    }
+
+                    if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+                }
+            }
+            return error.MissingFinalNewline;
+        }
+
+        inline fn int64(array: *const [8]u8) u64 {
+            return @bitCast(array.*);
+        }
+    };
 
-/// Reset this response to its initial state. This must be called before
-/// handling a second request on the same connection.
-pub fn reset(res: *Server) ResetState {
-    if (res.state == .first) {
-        res.state = .start;
-        return .reset;
-    }
-
-    if (!res.request.parser.done) {
-        // If the response wasn't fully read, then we need to close the connection.
-        res.connection_keep_alive = false;
-        return .closing;
-    }
-
-    res.state = .start;
-    res.request = Request.init(.{
-        .client_header_buffer = res.request.parser.header_bytes_buffer,
-    });
-
-    return if (res.connection_keep_alive) .reset else .closing;
-}
-
-pub const SendAllError = std.net.Stream.WriteError;
-
-pub const SendOptions = struct {
-    version: http.Version = .@"HTTP/1.1",
-    status: http.Status = .ok,
-    reason: ?[]const u8 = null,
-    keep_alive: bool = true,
-    extra_headers: []const http.Header = &.{},
-    content: []const u8,
-};
-
-/// Send an entire HTTP response to the client, including headers and body.
-/// Automatically handles HEAD requests by omitting the body.
-/// Uses the "content-length" header.
-/// Asserts status is not `continue`.
-/// Asserts there are at most 25 extra_headers.
-pub fn sendAll(s: *Server, options: SendOptions) SendAllError!void {
-    const max_extra_headers = 25;
-    assert(options.status != .@"continue");
-    assert(options.extra_headers.len <= max_extra_headers);
-
-    switch (s.state) {
-        .waited => s.state = .finished,
-        .first => unreachable, // Call reset() first.
-        .start => unreachable, // Call wait() first.
-        .responded => unreachable, // Cannot mix sendAll() with send().
-        .finished => unreachable, // Call reset() first.
-    }
-
-    s.connection_keep_alive = options.keep_alive and s.request.keep_alive;
-    const keep_alive_line = if (s.connection_keep_alive)
-        "connection: keep-alive\r\n"
-    else
-        "";
-    const phrase = options.reason orelse options.status.phrase() orelse "";
-
-    var first_buffer: [500]u8 = undefined;
-    const first_bytes = std.fmt.bufPrint(
-        &first_buffer,
-        "{s} {d} {s}\r\n{s}content-length: {d}\r\n",
-        .{
-            @tagName(options.version),
-            @intFromEnum(options.status),
-            phrase,
-            keep_alive_line,
-            options.content.len,
-        },
-    ) catch unreachable;
-
-    var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
-    var iovecs_len: usize = 0;
-
-    iovecs[iovecs_len] = .{
-        .iov_base = first_bytes.ptr,
-        .iov_len = first_bytes.len,
-    };
-    iovecs_len += 1;
-
-    for (options.extra_headers) |header| {
-        iovecs[iovecs_len] = .{
-            .iov_base = header.name.ptr,
-            .iov_len = header.name.len,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = ": ",
-            .iov_len = 2,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = header.value.ptr,
-            .iov_len = header.value.len,
-        };
-        iovecs_len += 1;
-
-        iovecs[iovecs_len] = .{
-            .iov_base = "\r\n",
-            .iov_len = 2,
-        };
-        iovecs_len += 1;
-    }
-
-    iovecs[iovecs_len] = .{
-        .iov_base = "\r\n",
-        .iov_len = 2,
-    };
-    iovecs_len += 1;
-
-    if (s.request.method != .HEAD) {
-        iovecs[iovecs_len] = .{
-            .iov_base = options.content.ptr,
-            .iov_len = options.content.len,
-        };
-        iovecs_len += 1;
-    }
-
-    return s.connection.stream.writevAll(iovecs[0..iovecs_len]);
-}
+    pub const RespondOptions = struct {
+        version: http.Version = .@"HTTP/1.1",
+        status: http.Status = .ok,
+        reason: ?[]const u8 = null,
+        keep_alive: bool = true,
+        extra_headers: []const http.Header = &.{},
+    };
+
+    /// Send an entire HTTP response to the client, including headers and body.
+    ///
+    /// Automatically handles HEAD requests by omitting the body.
+    /// Uses the "content-length" header unless `content` is empty in which
+    /// case it omits the content-length header.
+    ///
+    /// If the request contains a body and the connection is to be reused,
+    /// discards the request body, leaving the Server in the `ready` state. If
+    /// this discarding fails, the connection is marked as not to be reused and
+    /// no error is surfaced.
+    ///
+    /// Asserts status is not `continue`.
+    /// Asserts there are at most 25 extra_headers.
+    pub fn respond(
+        request: *Request,
+        content: []const u8,
+        options: RespondOptions,
+    ) Response.WriteError!void {
+        const max_extra_headers = 25;
+        assert(options.status != .@"continue");
+        assert(options.extra_headers.len <= max_extra_headers);
+
+        const keep_alive = request.discardBody(options.keep_alive);
+
+        const phrase = options.reason orelse options.status.phrase() orelse "";
+
+        var first_buffer: [500]u8 = undefined;
+        var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer);
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
+            @tagName(options.version), @intFromEnum(options.status), phrase,
+        }) catch unreachable; // first_buffer is sized for any status line
+        if (keep_alive)
+            h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
+        if (content.len > 0)
+            h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable;
+
+        var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
+        var iovecs_len: usize = 0;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = h.items.ptr,
+            .iov_len = h.items.len,
+        };
+        iovecs_len += 1;
+
+        for (options.extra_headers) |header| {
+            iovecs[iovecs_len] = .{
+                .iov_base = header.name.ptr,
+                .iov_len = header.name.len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = ": ",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = header.value.ptr,
+                .iov_len = header.value.len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        }
+
+        iovecs[iovecs_len] = .{
+            .iov_base = "\r\n",
+            .iov_len = 2,
+        };
+        iovecs_len += 1;
+
+        if (request.head.method != .HEAD and content.len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = content.ptr,
+                .iov_len = content.len,
+            };
+            iovecs_len += 1;
+        }
+
+        try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]);
+    }
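Status, reason phrase, keep-alive, and extra headers all route through `RespondOptions`. For example, rejecting a request could look like this (hypothetical handler, illustrative values):

    fn reject(request: *std.http.Server.Request) !void {
        try request.respond("payload too large\n", .{
            .status = .payload_too_large,
            .keep_alive = false,
            .extra_headers = &.{
                .{ .name = "content-type", .value = "text/plain" },
            },
        });
    }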
+    pub const RespondStreamingOptions = struct {
+        /// An externally managed slice of memory used to batch bytes before
+        /// sending. `respondStreaming` asserts this is large enough to store
+        /// the full HTTP response head.
+        ///
+        /// Must outlive the returned Response.
+        send_buffer: []u8,
+        /// If provided, the response will use the content-length header;
+        /// otherwise it will use transfer-encoding: chunked.
+        content_length: ?u64 = null,
+        /// Options that are shared with the `respond` method.
+        respond_options: RespondOptions = .{},
+    };
+
+    /// The header is buffered but not sent until Response.flush is called.
+    ///
+    /// If the request contains a body and the connection is to be reused,
+    /// discards the request body, leaving the Server in the `ready` state. If
+    /// this discarding fails, the connection is marked as not to be reused and
+    /// no error is surfaced.
+    ///
+    /// HEAD requests are handled transparently by setting a flag on the
+    /// returned Response to omit the body. However it may be worth noticing
+    /// that flag and skipping any expensive work that would otherwise need to
+    /// be done to satisfy the request.
+    ///
+    /// Asserts `send_buffer` is large enough to store the entire response header.
+    /// Asserts status is not `continue`.
+    pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response {
+        const o = options.respond_options;
+        assert(o.status != .@"continue");
+
+        const keep_alive = request.discardBody(o.keep_alive);
+        const phrase = o.reason orelse o.status.phrase() orelse "";
+
+        var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
+
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
+            @tagName(o.version), @intFromEnum(o.status), phrase,
+        }) catch unreachable; // asserted: send_buffer fits the response head
+        if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
+
+        if (options.content_length) |len| {
+            h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
+        } else {
+            h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
+        }
+
+        for (o.extra_headers) |header| {
+            h.appendSliceAssumeCapacity(header.name);
+            h.appendSliceAssumeCapacity(": ");
+            h.appendSliceAssumeCapacity(header.value);
+            h.appendSliceAssumeCapacity("\r\n");
+        }
+
+        h.appendSliceAssumeCapacity("\r\n");
+
+        return .{
+            .stream = request.server.connection.stream,
+            .send_buffer = options.send_buffer,
+            .send_buffer_start = 0,
+            .send_buffer_end = h.items.len,
+            .content_length = options.content_length,
+            .elide_body = request.head.method == .HEAD,
+            .chunk_len = 0,
+        };
+    }
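Leaving `content_length` unset selects chunked framing; each buffered run of bytes becomes one chunk when flushed. A sketch of streaming usage (hypothetical handler):

    fn serveCounter(request: *std.http.Server.Request) !void {
        var send_buffer: [4000]u8 = undefined;
        var response = request.respondStreaming(.{
            .send_buffer = &send_buffer,
        });
        var i: usize = 0;
        while (i < 10) : (i += 1) {
            try response.writer().print("line {d}\n", .{i});
        }
        // Writes the end-of-stream chunk and flushes.
        try response.end();
    }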
+
+    pub const ReadError = net.Stream.ReadError;
+
+    fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
+        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
+        const s = request.server;
+        assert(s.state == .receiving_body);
+
+        const remaining_content_length = &request.reader_state.remaining_content_length;
+
+        if (remaining_content_length.* == 0) {
+            s.state = .ready;
+            return 0;
+        }
+
+        // `next_request_start` doubles as the cursor for consumed body
+        // bytes; `reader` initialized it to `head_end`.
+        if (s.next_request_start == s.read_buffer_len) {
+            // Everything buffered has been consumed. The body region of the
+            // buffer can be recycled; the head bytes are preserved because
+            // `head` still references them.
+            s.next_request_start = request.head_end;
+            s.read_buffer_len = request.head_end;
+            const read_n = try s.connection.stream.read(s.read_buffer[request.head_end..]);
+            s.read_buffer_len += read_n;
+        }
+
+        const available_buf = s.read_buffer[s.next_request_start..s.read_buffer_len];
+        const len = @min(remaining_content_length.*, available_buf.len, buffer.len);
+        @memcpy(buffer[0..len], available_buf[0..len]);
+        s.next_request_start += len;
+        remaining_content_length.* -= len;
+        if (remaining_content_length.* == 0)
+            s.state = .ready;
+        return len;
+    }
+
+    fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
+        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
+        const s = request.server;
+        assert(s.state == .receiving_body);
+        _ = buffer;
+        @panic("TODO");
+    }
+
+    pub const ReadAllError = ReadError || error{HttpBodyOversize};
+
+    pub fn reader(request: *Request) std.io.AnyReader {
+        const s = request.server;
+        assert(s.state == .received_head);
+        s.state = .receiving_body;
+        // Start the body cursor just past the head; see read_cl.
+        s.next_request_start = request.head_end;
+        switch (request.head.transfer_encoding) {
+            .chunked => return .{
+                .readFn = read_chunked,
+                .context = request,
+            },
+            .none => {
+                request.reader_state = .{
+                    .remaining_content_length = request.head.content_length orelse 0,
+                };
+                return .{
+                    .readFn = read_cl,
+                    .context = request,
+                };
+            },
+        }
+    }
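With this patch only content-length bodies are readable (`read_chunked` is still a TODO), so reading an upload is a matter of draining the `AnyReader` (sketch; the 1 MiB cap is an arbitrary policy of the caller):

    fn readBody(request: *std.http.Server.Request, gpa: std.mem.Allocator) ![]u8 {
        const body_reader = request.reader();
        return body_reader.readAllAlloc(gpa, 1024 * 1024);
    }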
+
+    /// Returns whether the connection: keep-alive header should be sent to the client.
+    /// If it would fail, it instead sets the Server state to `receiving_body`
+    /// and returns false.
+    fn discardBody(request: *Request, keep_alive: bool) bool {
+        // Prepare to receive another request on the same connection.
+        // There are two factors to consider:
+        // * Any body the client sent must be discarded.
+        // * The Server's read_buffer may already have some bytes in it from
+        //   whatever came after the head, which may be the next HTTP request
+        //   or the request body.
+        // If the connection won't be kept alive, then none of this matters
+        // because the connection will be severed after the response is sent.
+        const s = request.server;
+        if (keep_alive and request.head.keep_alive) switch (s.state) {
+            .received_head => {
+                s.state = .receiving_body;
+                switch (request.head.transfer_encoding) {
+                    .none => t: {
+                        const len = request.head.content_length orelse break :t;
+                        const head_end = request.head_end;
+                        var total_body_discarded: usize = 0;
+                        while (true) {
+                            const available_bytes = s.read_buffer_len - head_end;
+                            const remaining_len = len - total_body_discarded;
+                            if (available_bytes >= remaining_len) {
+                                s.next_request_start = head_end + remaining_len;
+                                break :t;
+                            }
+                            total_body_discarded += available_bytes;
+                            // Preserve request header memory until receiveHead is called.
+                            const buf = s.read_buffer[head_end..];
+                            const read_n = s.connection.stream.read(buf) catch return false;
+                            if (read_n == 0) return false; // closed before the body ended
+                            s.read_buffer_len = head_end + read_n;
+                        }
+                    },
+                    .chunked => {
+                        @panic("TODO");
+                    },
+                }
+                s.state = .ready;
+                return true;
+            },
+            .receiving_body, .ready => return true,
+            else => unreachable,
+        } else {
+            s.state = .closing;
+            return false;
+        }
+    }
+};
 
 pub const Response = struct {
-    transfer_encoding: ResponseTransfer,
-};
+    stream: net.Stream,
+    send_buffer: []u8,
+    /// Index of the first byte in `send_buffer`.
+    /// This is 0 unless a short write happens in `write`.
+    send_buffer_start: usize,
+    /// Index of the last byte + 1 in `send_buffer`.
+    send_buffer_end: usize,
+    /// `null` means transfer-encoding: chunked.
+    /// As a debugging utility, counts down to zero as bytes are written.
+    content_length: ?u64,
+    elide_body: bool,
+    /// Indicates how much of the end of the `send_buffer` corresponds to a
+    /// chunk. This amount of data will be wrapped by an HTTP chunk header.
+    chunk_len: usize,
 
-pub const SendError = Connection.WriteError || error{
-    UnsupportedTransferEncoding,
-    InvalidContentLength,
-};
+    pub const WriteError = net.Stream.WriteError;
 
-/// Send the HTTP response headers to the client.
-pub fn send(res: *Server) SendError!void {
-    switch (res.state) {
-        .waited => res.state = .responded,
-        .first, .start, .responded, .finished => unreachable,
-    }
-
-    var buffered = std.io.bufferedWriter(res.connection.writer());
-    const w = buffered.writer();
-
-    try w.writeAll(@tagName(res.version));
-    try w.writeByte(' ');
-    try w.print("{d}", .{@intFromEnum(res.status)});
-    try w.writeByte(' ');
-    if (res.reason) |reason| {
-        try w.writeAll(reason);
-    } else if (res.status.phrase()) |phrase| {
-        try w.writeAll(phrase);
-    }
-    try w.writeAll("\r\n");
-
-    if (res.status == .@"continue") {
-        res.state = .waited; // we still need to send another request after this
-    } else {
-        res.connection_keep_alive = res.keep_alive and res.request.keep_alive;
-        if (res.connection_keep_alive) {
-            try w.writeAll("connection: keep-alive\r\n");
-        } else {
-            try w.writeAll("connection: close\r\n");
-        }
-
-        switch (res.transfer_encoding) {
-            .chunked => try w.writeAll("transfer-encoding: chunked\r\n"),
-            .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}),
-            .none => {},
-        }
-
-        for (res.extra_headers) |header| {
-            try w.print("{s}: {s}\r\n", .{ header.name, header.value });
-        }
-    }
-
-    if (res.request.method == .HEAD) {
-        res.transfer_encoding = .none;
-    }
-
-    try w.writeAll("\r\n");
-
-    try buffered.flush();
-}
+    /// When using content-length, asserts that the amount of data sent matches
+    /// the value sent in the header, then calls `flush`.
+    /// Otherwise, transfer-encoding: chunked is being used, and it writes the
+    /// end-of-stream message, then flushes the stream to the system.
+    /// When request method is HEAD, only the response head is written.
+    pub fn end(r: *Response) WriteError!void {
+        if (r.content_length) |len| {
+            assert(len == 0); // Trips when end() called before all bytes written.
+            try flush_cl(r);
+        } else if (!r.elide_body) {
+            try flush_chunked(r, &.{});
+        } else {
+            // HEAD response: the buffered head still needs to reach the
+            // client, but without any chunk framing.
+            try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
+        }
+        r.* = undefined;
+    }
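When the length is known up front, passing `content_length` switches the framing, and `end` double-checks the promise made in the header (sketch):

    fn serveBlob(request: *std.http.Server.Request, blob: []const u8) !void {
        var send_buffer: [1024]u8 = undefined;
        var response = request.respondStreaming(.{
            .send_buffer = &send_buffer,
            .content_length = blob.len,
        });
        try response.writeAll(blob);
        // Asserts exactly blob.len bytes were written, then flushes.
        try response.end();
    }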
+    pub const EndChunkedOptions = struct {
+        trailers: []const http.Header = &.{},
+    };
 
-const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
-
-const TransferReader = std.io.Reader(*Server, TransferReadError, transferRead);
-
-fn transferReader(res: *Server) TransferReader {
-    return .{ .context = res };
-}
-
-fn transferRead(res: *Server, buf: []u8) TransferReadError!usize {
-    if (res.request.parser.done) return 0;
-
-    var index: usize = 0;
-    while (index == 0) {
-        const amt = try res.request.parser.read(&res.connection, buf[index..], false);
-        if (amt == 0 and res.request.parser.done) break;
-        index += amt;
-    }
-
-    return index;
-}
+    /// Asserts that the Response is using transfer-encoding: chunked.
+    /// Writes the end-of-stream message and any optional trailers, then
+    /// flushes the stream to the system.
+    /// When request method is HEAD, only the response head is written.
+    /// Asserts there are at most 25 trailers.
+    pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void {
+        assert(r.content_length == null);
+        if (r.elide_body) {
+            // HEAD response: send the buffered head without chunk framing.
+            try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
+        } else {
+            try flush_chunked(r, options.trailers);
+        }
+        r.* = undefined;
+    }
 
-pub const WaitError = Connection.ReadError ||
-    proto.HeadersParser.CheckCompleteHeadError || Request.ParseError ||
-    error{CompressionUnsupported};
-
-/// Wait for the client to send a complete request head.
-///
-/// For correct behavior, the following rules must be followed:
-///
-/// * If this returns any error in `Connection.ReadError`, you MUST
-///   immediately close the connection by calling `deinit`.
-/// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close
-///   the connection by calling `deinit`.
-/// * If this returns `error.HttpHeadersOversize`, you MUST
-///   respond with a 431 status code and then call `deinit`.
-/// * If this returns any error in `Request.ParseError`, you MUST respond
-///   with a 400 status code and then call `deinit`.
-/// * If this returns any other error, you MUST respond with a 400 status
-///   code and then call `deinit`.
-/// * If the request has an Expect header containing 100-continue, you MUST either:
-///   * Respond with a 100 status code, then call `wait` again.
-///   * Respond with a 417 status code.
-pub fn wait(res: *Server) WaitError!void {
-    switch (res.state) {
-        .first, .start => res.state = .waited,
-        .waited, .responded, .finished => unreachable,
-    }
-
-    while (true) {
-        try res.connection.fill();
-
-        const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
-        res.connection.drop(@intCast(nchecked));
-
-        if (res.request.parser.state.isContent()) break;
-    }
-
-    try res.request.parse(res.request.parser.get());
-
-    switch (res.request.transfer_encoding) {
-        .none => {
-            if (res.request.content_length) |len| {
-                res.request.parser.next_chunk_length = len;
-
-                if (len == 0) res.request.parser.done = true;
-            } else {
-                res.request.parser.done = true;
-            }
-        },
-        .chunked => {
-            res.request.parser.next_chunk_length = 0;
-            res.request.parser.state = .chunk_head_size;
-        },
-    }
-
-    if (!res.request.parser.done) {
-        switch (res.request.transfer_compression) {
-            .identity => res.request.compression = .none,
-            .compress, .@"x-compress" => return error.CompressionUnsupported,
-            .deflate => res.request.compression = .{
-                .deflate = std.compress.zlib.decompressor(res.transferReader()),
-            },
-            .gzip, .@"x-gzip" => res.request.compression = .{
-                .gzip = std.compress.gzip.decompressor(res.transferReader()),
-            },
-            .zstd => {
-                // https://github.com/ziglang/zig/issues/18937
-                return error.CompressionUnsupported;
-            },
-        }
-    }
-}
+    /// If using content-length, asserts that writing these bytes to the client
+    /// would not exceed the content-length value sent in the HTTP header.
+    /// May return 0, which does not indicate end of stream. The caller decides
+    /// when the end of stream occurs by calling `end`.
+    pub fn write(r: *Response, bytes: []const u8) WriteError!usize {
+        if (r.content_length != null) {
+            return write_cl(r, bytes);
+        } else {
+            return write_chunked(r, bytes);
+        }
+    }
 
-pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
-
-pub const Reader = std.io.Reader(*Server, ReadError, read);
-
-pub fn reader(res: *Server) Reader {
-    return .{ .context = res };
-}
-
-/// Reads data from the response body. Must be called after `wait`.
-pub fn read(res: *Server, buffer: []u8) ReadError!usize {
-    switch (res.state) {
-        .waited, .responded, .finished => {},
-        .first, .start => unreachable,
-    }
-
-    const out_index = switch (res.request.compression) {
-        .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
-        .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
-        // https://github.com/ziglang/zig/issues/18937
-        //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
-        else => try res.transferRead(buffer),
-    };
-
-    if (out_index == 0) {
-        const has_trail = !res.request.parser.state.isContent();
-
-        while (!res.request.parser.state.isContent()) { // read trailing headers
-            try res.connection.fill();
-
-            const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
-            res.connection.drop(@intCast(nchecked));
-        }
-
-        if (has_trail) {
-            // The response headers before the trailers are already
-            // guaranteed to be valid, so they will always be parsed again
-            // and cannot return an error.
-            // This will *only* fail for a malformed trailer.
-            res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers;
-        }
-    }
-
-    return out_index;
-}
-
-/// Reads data from the response body. Must be called after `wait`.
-pub fn readAll(res: *Server, buffer: []u8) !usize {
-    var index: usize = 0;
-    while (index < buffer.len) {
-        const amt = try read(res, buffer[index..]);
-        if (amt == 0) break;
-        index += amt;
-    }
-    return index;
-}
+    fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize {
+        const r: *Response = @constCast(@alignCast(@ptrCast(context)));
+
+        const len = &r.content_length.?;
+
+        if (r.elide_body) {
+            len.* -= bytes.len;
+            return bytes.len;
+        }
+
+        if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
+            const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
+            var iovecs: [2]std.posix.iovec_const = .{
+                .{
+                    .iov_base = r.send_buffer.ptr + r.send_buffer_start,
+                    .iov_len = send_buffer_len,
+                },
+                .{
+                    .iov_base = bytes.ptr,
+                    .iov_len = bytes.len,
+                },
+            };
+            const n = try r.stream.writev(&iovecs);
+
+            if (n >= send_buffer_len) {
+                // It was enough to reset the buffer.
+                r.send_buffer_start = 0;
+                r.send_buffer_end = 0;
+                const bytes_n = n - send_buffer_len;
+                len.* -= bytes_n;
+                return bytes_n;
+            }
+
+            // It didn't even make it through the existing buffer, let
+            // alone the new bytes provided.
+            r.send_buffer_start += n;
+            return 0;
+        }
+
+        // All bytes can be stored in the remaining space of the buffer.
+        @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes);
+        r.send_buffer_end += bytes.len;
+        len.* -= bytes.len;
+        return bytes.len;
+    }
 
-pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
-
-pub const Writer = std.io.Writer(*Server, WriteError, write);
-
-pub fn writer(res: *Server) Writer {
-    return .{ .context = res };
-}
-
-/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
-/// Must be called after `send` and before `finish`.
-pub fn write(res: *Server, bytes: []const u8) WriteError!usize {
-    switch (res.state) {
-        .responded => {},
-        .first, .waited, .start, .finished => unreachable,
-    }
-
-    switch (res.transfer_encoding) {
-        .chunked => {
-            if (bytes.len > 0) {
-                try res.connection.writer().print("{x}\r\n", .{bytes.len});
-                try res.connection.writeAll(bytes);
-                try res.connection.writeAll("\r\n");
-            }
-
-            return bytes.len;
-        },
-        .content_length => |*len| {
-            if (len.* < bytes.len) return error.MessageTooLong;
-
-            const amt = try res.connection.write(bytes);
-            len.* -= amt;
-            return amt;
-        },
-        .none => return error.NotWriteable,
-    }
-}
+    fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize {
+        const r: *Response = @constCast(@alignCast(@ptrCast(context)));
+        assert(r.content_length == null);
+
+        if (r.elide_body) return bytes.len;
+
+        if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
+            const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
+            const chunk_len = r.chunk_len + bytes.len;
+            var header_buf: [18]u8 = undefined;
+            const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable;
+
+            var iovecs: [5]std.posix.iovec_const = .{
+                .{
+                    .iov_base = r.send_buffer.ptr + r.send_buffer_start,
+                    .iov_len = send_buffer_len - r.chunk_len,
+                },
+                .{
+                    .iov_base = chunk_header.ptr,
+                    .iov_len = chunk_header.len,
+                },
+                .{
+                    .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
+                    .iov_len = r.chunk_len,
+                },
+                .{
+                    .iov_base = bytes.ptr,
+                    .iov_len = bytes.len,
+                },
+                .{
+                    .iov_base = "\r\n",
+                    .iov_len = 2,
+                },
+            };
+            // TODO make this writev instead of writevAll, which involves
+            // complicating the logic of this function.
+            try r.stream.writevAll(&iovecs);
+            r.send_buffer_start = 0;
+            r.send_buffer_end = 0;
+            r.chunk_len = 0;
+            return bytes.len;
+        }
+
+        // All bytes can be stored in the remaining space of the buffer.
+        @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes);
+        r.send_buffer_end += bytes.len;
+        r.chunk_len += bytes.len;
+        return bytes.len;
+    }
 
-/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
-/// Must be called after `send` and before `finish`.
-pub fn writeAll(req: *Server, bytes: []const u8) WriteError!void {
-    var index: usize = 0;
-    while (index < bytes.len) {
-        index += try write(req, bytes[index..]);
-    }
-}
+    /// If using content-length, asserts that writing these bytes to the client
+    /// would not exceed the content-length value sent in the HTTP header.
+    pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try write(r, bytes[index..]);
+        }
+    }
 
-pub const FinishError = Connection.WriteError || error{MessageNotCompleted};
-
-/// Finish the body of a request. This notifies the server that you have no more data to send.
-/// Must be called after `send`.
-pub fn finish(res: *Server) FinishError!void {
-    switch (res.state) {
-        .responded => res.state = .finished,
-        .first, .waited, .start, .finished => unreachable,
-    }
-
-    switch (res.transfer_encoding) {
-        .chunked => try res.connection.writeAll("0\r\n\r\n"),
-        .content_length => |len| if (len != 0) return error.MessageNotCompleted,
-        .none => {},
-    }
-}
-
-const builtin = @import("builtin");
+    /// Sends all buffered data to the client.
+    /// This is redundant after calling `end`.
+    pub fn flush(r: *Response) WriteError!void {
+        if (r.content_length != null) {
+            return flush_cl(r);
+        } else {
+            return flush_chunked(r, null);
+        }
+    }
+
+    fn flush_cl(r: *Response) WriteError!void {
+        assert(r.content_length != null);
+        try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
+        r.send_buffer_start = 0;
+        r.send_buffer_end = 0;
+    }
+
+    fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void {
+        const max_trailers = 25;
+        if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
+        assert(r.content_length == null);
+
+        if (end_trailers == null and r.chunk_len == 0) {
+            // Mid-stream flush with no pending chunk data. Send what is
+            // buffered (such as the response head) without a chunk frame;
+            // emitting "0\r\n" here would prematurely end the stream.
+            try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
+            r.send_buffer_start = 0;
+            r.send_buffer_end = 0;
+            return;
+        }
+
+        const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
+        var header_buf: [18]u8 = undefined;
+        const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable;
+
+        var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined;
+        var iovecs_len: usize = 0;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = r.send_buffer.ptr + r.send_buffer_start,
+            .iov_len = send_buffer_len - r.chunk_len,
+        };
+        iovecs_len += 1;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = chunk_header.ptr,
+            .iov_len = chunk_header.len,
+        };
+        iovecs_len += 1;
+
+        iovecs[iovecs_len] = .{
+            .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
+            .iov_len = r.chunk_len,
+        };
+        iovecs_len += 1;
+
+        if (end_trailers) |trailers| {
+            if (r.chunk_len > 0) {
+                iovecs[iovecs_len] = .{
+                    .iov_base = "\r\n0\r\n",
+                    .iov_len = 5,
+                };
+                iovecs_len += 1;
+            }
+
+            for (trailers) |trailer| {
+                iovecs[iovecs_len] = .{
+                    .iov_base = trailer.name.ptr,
+                    .iov_len = trailer.name.len,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = ": ",
+                    .iov_len = 2,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = trailer.value.ptr,
+                    .iov_len = trailer.value.len,
+                };
+                iovecs_len += 1;
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = "\r\n",
+                    .iov_len = 2,
+                };
+                iovecs_len += 1;
+            }
+
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        } else if (r.chunk_len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        }
+
+        try r.stream.writevAll(iovecs[0..iovecs_len]);
+        r.send_buffer_start = 0;
+        r.send_buffer_end = 0;
+        r.chunk_len = 0;
+    }
+
+    pub fn writer(r: *Response) std.io.AnyWriter {
+        return .{
+            .writeFn = if (r.content_length != null) write_cl else write_chunked,
+            .context = r,
+        };
+    }
+};
 
 const std = @import("../std.zig");
-const testing = std.testing;
 const http = std.http;
 const mem = std.mem;
 const net = std.net;
@@ -615,4 +807,3 @@ const Uri = std.Uri;
 const assert = std.debug.assert;
 const Server = @This();
 
-const proto = @import("protocol.zig");
return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); -} - -pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; -} - -pub fn drop(conn: *Connection, num: u16) void { - conn.read_start += num; -} - -pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); - - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } - - try conn.fill(); - } - - return out_index; -} - -pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); -} - -pub const ReadError = error{ - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, -}; - -pub const Reader = std.io.Reader(*Connection, ReadError, read); - -pub fn reader(conn: *Connection) Reader { - return .{ .context = conn }; -} - -pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; -} - -pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - return conn.stream.write(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; -} - -pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, -}; - -pub const Writer = std.io.Writer(*Connection, WriteError, write); - -pub fn writer(conn: *Connection) Writer { - return .{ .context = conn }; -} - -pub fn close(conn: *Connection) void { - conn.stream.close(); -} - -const Connection = @This(); -const std = @import("../../std.zig"); -const assert = std.debug.assert; diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 4c69a79105..d080d3c389 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -73,339 +73,28 @@ pub const HeadersParser = struct { return hp.header_bytes_buffer[0..hp.header_bytes_len]; } - /// Returns the number of bytes consumed by headers. This is always less - /// than or equal to `bytes.len`. - /// You should check `r.state.isContent()` after this to check if the - /// headers are done. - /// - /// If the amount returned is less than `bytes.len`, you may assume that - /// the parser is in a content state and the - /// first byte of content is located at `bytes[result]`. 
pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8); - const len: u32 = @intCast(bytes.len); - var index: u32 = 0; - - while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => return index, - .start => switch (len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - return index + 1; - }, - 2 => { - const b16 = int16(bytes[index..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - return index + 2; - }, - 3 => { - const b24 = int24(bytes[index..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - return index + 3; - }, - 4...vector_len - 1 => { - const b32 = int32(bytes[index..][0..4]); - const b24 = intShift(u24, b32); - const b16 = intShift(u16, b32); - const b8 = intShift(u8, b32); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - switch (b32) { - int32("\r\n\r\n") => r.state = .finished, - else => {}, - } - - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..vector_len]; - const matches = if (use_vectors) matches: { - const Vector = @Vector(vector_len, u8); - // const BoolVector = @Vector(vector_len, bool); - const BitVector = @Vector(vector_len, u1); - const SizeVector = @Vector(vector_len, u8); - - const v: Vector = chunk.*; - const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r'))); - const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n'))); - const matches_or: SizeVector = matches_r | matches_n; - - break :matches @reduce(.Add, matches_or); - } else matches: { - var matches: u8 = 0; - for (chunk) |byte| switch (byte) { - '\r', '\n' => matches += 1, - else => {}, - }; - break :matches matches; - }; - switch (matches) { - 0 => {}, - 1 => switch (chunk[vector_len - 1]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - }, - 2 => { - const b16 = int16(chunk[vector_len - 2 ..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - }, - 3 => { - const b24 = int24(chunk[vector_len - 3 ..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - }, - 4...vector_len => { - inline for (0..vector_len - 3) |i_usize| { - const i = @as(u32, @truncate(i_usize)); - - 
-                                    const b32 = int32(chunk[i..][0..4]);
-                                    const b16 = intShift(u16, b32);
-
-                                    if (b32 == int32("\r\n\r\n")) {
-                                        r.state = .finished;
-                                        return index + i + 4;
-                                    } else if (b16 == int16("\n\n")) {
-                                        r.state = .finished;
-                                        return index + i + 2;
-                                    }
-                                }
-
-                                const b24 = int24(chunk[vector_len - 3 ..][0..3]);
-                                const b16 = intShift(u16, b24);
-                                const b8 = intShift(u8, b24);
-
-                                switch (b8) {
-                                    '\r' => r.state = .seen_r,
-                                    '\n' => r.state = .seen_n,
-                                    else => {},
-                                }
-
-                                switch (b16) {
-                                    int16("\r\n") => r.state = .seen_rn,
-                                    int16("\n\n") => r.state = .finished,
-                                    else => {},
-                                }
-
-                                switch (b24) {
-                                    int24("\r\n\r") => r.state = .seen_rnr,
-                                    else => {},
-                                }
-                            },
-                            else => unreachable,
-                        }
-
-                        index += vector_len;
-                        continue;
-                    },
-                },
-                .seen_n => switch (len - index) {
-                    0 => return index,
-                    else => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .finished,
-                            else => r.state = .start,
-                        }
-
-                        index += 1;
-                        continue;
-                    },
-                },
-                .seen_r => switch (len - index) {
-                    0 => return index,
-                    1 => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .seen_rn,
-                            '\r' => r.state = .seen_r,
-                            else => r.state = .start,
-                        }
-
-                        return index + 1;
-                    },
-                    2 => {
-                        const b16 = int16(bytes[index..][0..2]);
-                        const b8 = intShift(u8, b16);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_r,
-                            '\n' => r.state = .seen_rn,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\r") => r.state = .seen_rnr,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        return index + 2;
-                    },
-                    else => {
-                        const b24 = int24(bytes[index..][0..3]);
-                        const b16 = intShift(u16, b24);
-                        const b8 = intShift(u8, b24);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_r,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        switch (b24) {
-                            int24("\n\r\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        index += 3;
-                        continue;
-                    },
-                },
-                .seen_rn => switch (len - index) {
-                    0 => return index,
-                    1 => {
-                        switch (bytes[index]) {
-                            '\r' => r.state = .seen_rnr,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        return index + 1;
-                    },
-                    else => {
-                        const b16 = int16(bytes[index..][0..2]);
-                        const b8 = intShift(u8, b16);
-
-                        switch (b8) {
-                            '\r' => r.state = .seen_rnr,
-                            '\n' => r.state = .seen_n,
-                            else => r.state = .start,
-                        }
-
-                        switch (b16) {
-                            int16("\r\n") => r.state = .finished,
-                            int16("\n\n") => r.state = .finished,
-                            else => {},
-                        }
-
-                        index += 2;
-                        continue;
-                    },
-                },
-                .seen_rnr => switch (len - index) {
-                    0 => return index,
-                    else => {
-                        switch (bytes[index]) {
-                            '\n' => r.state = .finished,
-                            else => r.state = .start,
-                        }
-
-                        index += 1;
-                        continue;
-                    },
-                },
-                .chunk_head_size => unreachable,
-                .chunk_head_ext => unreachable,
-                .chunk_head_r => unreachable,
-                .chunk_data => unreachable,
-                .chunk_data_suffix => unreachable,
-                .chunk_data_suffix_r => unreachable,
-            }
-
-            return index;
-        }
+        // std.http.HeadParser implements the same head-scanning state machine
+        // that previously lived here; translate the state enum across the
+        // call in both directions and delegate the scanning work to it.
+        var hp: std.http.HeadParser = .{
+            .state = switch (r.state) {
+                .start => .start,
+                .seen_n => .seen_n,
+                .seen_r => .seen_r,
+                .seen_rn => .seen_rn,
+                .seen_rnr => .seen_rnr,
+                .finished => .finished,
+                else => unreachable,
+            },
+        };
+        const result = hp.feed(bytes);
+        r.state = switch (hp.state) {
+            .start => .start,
+            .seen_n => .seen_n,
+            .seen_r => .seen_r,
+            .seen_rn => .seen_rn,
+            .seen_rnr => .seen_rnr,
+            .finished => .finished,
+        };
+        return @intCast(result);
     }
 
     /// Returns the number of bytes consumed by the chunk size. This is always
@@ -775,17 +464,6 @@ const MockBufferedConnection = struct {
     }
 };
 
-test "HeadersParser.findHeadersEnd" {
-    var r: HeadersParser = undefined;
-    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
-
-    for (0..36) |i| {
-        r = HeadersParser.init(&.{});
-        try std.testing.expectEqual(@as(u32, @intCast(i)), r.findHeadersEnd(data[0..i]));
-        try std.testing.expectEqual(@as(u32, @intCast(35 - i)), r.findHeadersEnd(data[i..]));
-    }
-}
-
 test "HeadersParser.findChunkedLen" {
     var r: HeadersParser = undefined;
     const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig
index 3441346baf..3dbccbcff6 100644
--- a/lib/std/http/test.zig
+++ b/lib/std/http/test.zig
@@ -69,31 +69,35 @@ test "trailers" {
 fn serverThread(http_server: *std.net.Server) anyerror!void {
     var header_buffer: [1024]u8 = undefined;
     var remaining: usize = 1;
-    accept: while (remaining != 0) : (remaining -= 1) {
+    while (remaining != 0) : (remaining -= 1) {
         const conn = try http_server.accept();
         defer conn.stream.close();
 
-        var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
+        var server = std.http.Server.init(conn, &header_buffer);
 
-        res.wait() catch |err| switch (err) {
-            error.HttpHeadersInvalid => continue :accept,
-            error.EndOfStream => continue,
-            else => return err,
-        };
-        try serve(&res);
-
-        try testing.expectEqual(.reset, res.reset());
+        try testing.expectEqual(.ready, server.state);
+        var request = try server.receiveHead();
+        try serve(&request);
+        try testing.expectEqual(.ready, server.state);
     }
 }
 
-fn serve(res: *std.http.Server) !void {
-    try testing.expectEqualStrings(res.request.target, "/trailer");
-    res.transfer_encoding = .chunked;
+fn serve(request: *std.http.Server.Request) !void {
+    try testing.expectEqualStrings(request.head.target, "/trailer");
 
-    try res.send();
-    try res.writeAll("Hello, ");
-    try res.writeAll("World!\n");
-    try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
+    var send_buffer: [1024]u8 = undefined;
+    var response = request.respondStreaming(.{
+        .send_buffer = &send_buffer,
+    });
+    try response.writeAll("Hello, ");
+    try response.flush();
+    try response.writeAll("World!\n");
+    try response.flush();
+    try response.endChunked(.{
+        .trailers = &.{
+            .{ .name = "X-Checksum", .value = "aaaa" },
+        },
+    });
 }
 
 test "HTTP server handles a chunked transfer coding request" {
@@ -116,34 +120,33 @@ test "HTTP server handles a chunked transfer coding request" {
     const max_header_size = 8192;
 
     const address = try std.net.Address.parseIp("127.0.0.1", 0);
-    var server = try address.listen(.{ .reuse_address = true });
-    defer server.deinit();
-    const server_port = server.listen_address.in.getPort();
+    var socket_server = try address.listen(.{ .reuse_address = true });
+    defer socket_server.deinit();
+    const server_port = socket_server.listen_address.in.getPort();
 
     const server_thread = try std.Thread.spawn(.{}, (struct {
-        fn apply(s: *std.net.Server) !void {
+        fn apply(net_server: *std.net.Server) !void {
             var header_buffer: [max_header_size]u8 = undefined;
-            const conn = try s.accept();
+            const conn = try net_server.accept();
             defer conn.stream.close();
 
-            var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
-            try res.wait();
-            try expect(res.request.transfer_encoding == .chunked);
-            const server_body: []const u8 = "message from server!\n";
-            res.transfer_encoding = .{ .content_length = server_body.len };
-            res.extra_headers = &.{
= "content-type", .value = "text/plain" }, - }; - res.keep_alive = false; - try res.send(); + var server = std.http.Server.init(conn, &header_buffer); + var request = try server.receiveHead(); + + try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); + const n = try request.reader().readAll(&buf); try expect(std.mem.eql(u8, buf[0..n], "ABCD")); - _ = try res.writer().writeAll(server_body); - try res.finish(); + + try request.respond("message from server!\n", .{ + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + .keep_alive = false, + }); } - }).apply, .{&server}); + }).apply, .{&socket_server}); const request_bytes = "POST / HTTP/1.1\r\n" ++