From 6e671d4c779dc2087b0687f9e5ed5cd7a3341ea9 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 31 Jul 2025 22:36:08 -0700 Subject: [PATCH] std.http: rework for new std.Io API --- lib/std/http.zig | 781 ++++++++++++- lib/std/http/ChunkParser.zig | 6 +- lib/std/http/Client.zig | 2052 +++++++++++++++++----------------- lib/std/http/Server.zig | 1140 +++++++------------ lib/std/http/WebSocket.zig | 246 ---- lib/std/http/protocol.zig | 464 -------- lib/std/http/test.zig | 594 +++++----- 7 files changed, 2449 insertions(+), 2834 deletions(-) delete mode 100644 lib/std/http/WebSocket.zig delete mode 100644 lib/std/http/protocol.zig diff --git a/lib/std/http.zig b/lib/std/http.zig index 5bf12a1876..6075a2fe6d 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,14 +1,14 @@ const builtin = @import("builtin"); const std = @import("std.zig"); const assert = std.debug.assert; +const Writer = std.Io.Writer; +const File = std.fs.File; pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); -pub const protocol = @import("http/protocol.zig"); pub const HeadParser = @import("http/HeadParser.zig"); pub const ChunkParser = @import("http/ChunkParser.zig"); pub const HeaderIterator = @import("http/HeaderIterator.zig"); -pub const WebSocket = @import("http/WebSocket.zig"); pub const Version = enum { @"HTTP/1.0", @@ -42,7 +42,7 @@ pub const Method = enum(u64) { return x; } - pub fn format(self: Method, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn format(self: Method, w: *Writer) Writer.Error!void { const bytes: []const u8 = @ptrCast(&@intFromEnum(self)); const str = std.mem.sliceTo(bytes, 0); try w.writeAll(str); @@ -296,13 +296,24 @@ pub const TransferEncoding = enum { }; pub const ContentEncoding = enum { - identity, - compress, - @"x-compress", - deflate, - gzip, - @"x-gzip", zstd, + gzip, + deflate, + compress, + identity, + + pub fn fromString(s: []const u8) ?ContentEncoding { + const map = std.StaticStringMap(ContentEncoding).initComptime(.{ + .{ "zstd", .zstd }, + .{ "gzip", .gzip }, + .{ "x-gzip", .gzip }, + .{ "deflate", .deflate }, + .{ "compress", .compress }, + .{ "x-compress", .compress }, + .{ "identity", .identity }, + }); + return map.get(s); + } }; pub const Connection = enum { @@ -315,15 +326,755 @@ pub const Header = struct { value: []const u8, }; +pub const Reader = struct { + in: *std.Io.Reader, + /// This is preallocated memory that might be used by `bodyReader`. That + /// function might return a pointer to this field, or a different + /// `*std.Io.Reader`. Advisable to not access this field directly. + interface: std.Io.Reader, + /// Keeps track of whether the stream is ready to accept a new request, + /// making invalid API usage cause assertion failures rather than HTTP + /// protocol violations. + state: State, + /// HTTP trailer bytes. These are at the end of a transfer-encoding: + /// chunked message. This data is available only after calling one of the + /// "end" functions and points to data inside the buffer of `in`, and is + /// therefore invalidated on the next call to `receiveHead`, or any other + /// read from `in`. + trailers: []const u8 = &.{}, + body_err: ?BodyError = null, + /// Stolen from `in`. + head_buffer: []u8 = &.{}, + + pub const max_chunk_header_len = 22; + + pub const RemainingChunkLen = enum(u64) { + head = 0, + n = 1, + rn = 2, + _, + + pub fn init(integer: u64) RemainingChunkLen { + return @enumFromInt(integer); + } + + pub fn int(rcl: RemainingChunkLen) u64 { + return @intFromEnum(rcl); + } + }; + + pub const State = union(enum) { + /// The stream is available to be used for the first time, or reused. + ready, + received_head, + /// The stream goes until the connection is closed. + body_none, + body_remaining_content_length: u64, + body_remaining_chunk_len: RemainingChunkLen, + /// The stream would be eligible for another HTTP request, however the + /// client and server did not negotiate a persistent connection. + closing, + }; + + pub const BodyError = error{ + HttpChunkInvalid, + HttpChunkTruncated, + HttpHeadersOversize, + }; + + pub const HeadError = error{ + /// Too many bytes of HTTP headers. + /// + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Partial HTTP request was received but the connection was closed + /// before fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. This + /// happens when a keep-alive connection is finally closed. + HttpConnectionClosing, + /// Transitive error occurred reading from `in`. + ReadFailed, + }; + + pub fn restituteHeadBuffer(reader: *Reader) void { + reader.in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; + } + + /// Buffers the entire head into `head_buffer`, invalidating the previous + /// `head_buffer`, if any. + pub fn receiveHead(reader: *Reader) HeadError!void { + reader.trailers = &.{}; + const in = reader.in; + in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; + in.rebase(); + var hp: HeadParser = .{}; + var head_end: usize = 0; + while (true) { + if (head_end >= in.buffer.len) return error.HttpHeadersOversize; + in.fillMore() catch |err| switch (err) { + error.EndOfStream => switch (head_end) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }, + error.ReadFailed => return error.ReadFailed, + }; + head_end += hp.feed(in.buffered()[head_end..]); + if (hp.state == .finished) { + reader.head_buffer = in.steal(head_end); + reader.state = .received_head; + return; + } + } + } + + /// If compressed body has been negotiated this will return compressed bytes. + /// + /// Asserts only called once and after `receiveHead`. + /// + /// See also: + /// * `interfaceDecompressing` + pub fn bodyReader( + reader: *Reader, + buffer: []u8, + transfer_encoding: TransferEncoding, + content_length: ?u64, + ) *std.Io.Reader { + assert(reader.state == .received_head); + switch (transfer_encoding) { + .chunked => { + reader.state = .{ .body_remaining_chunk_len = .head }; + reader.interface = .{ + .buffer = buffer, + .seek = 0, + .end = 0, + .vtable = &.{ + .stream = chunkedStream, + .discard = chunkedDiscard, + }, + }; + return &reader.interface; + }, + .none => { + if (content_length) |len| { + reader.state = .{ .body_remaining_content_length = len }; + reader.interface = .{ + .buffer = buffer, + .seek = 0, + .end = 0, + .vtable = &.{ + .stream = contentLengthStream, + .discard = contentLengthDiscard, + }, + }; + return &reader.interface; + } else { + reader.state = .body_none; + return reader.in; + } + }, + } + } + + /// If compressed body has been negotiated this will return decompressed bytes. + /// + /// Asserts only called once and after `receiveHead`. + /// + /// See also: + /// * `interface` + pub fn bodyReaderDecompressing( + reader: *Reader, + transfer_encoding: TransferEncoding, + content_length: ?u64, + content_encoding: ContentEncoding, + decompressor: *Decompressor, + decompression_buffer: []u8, + ) *std.Io.Reader { + if (transfer_encoding == .none and content_length == null) { + assert(reader.state == .received_head); + reader.state = .body_none; + switch (content_encoding) { + .identity => { + return reader.in; + }, + .deflate => { + decompressor.* = .{ .flate = .init(reader.in, .raw, decompression_buffer) }; + return &decompressor.flate.reader; + }, + .gzip => { + decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) }; + return &decompressor.flate.reader; + }, + .zstd => { + decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) }; + return &decompressor.zstd.reader; + }, + .compress => unreachable, + } + } + const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length); + return decompressor.init(transfer_reader, decompression_buffer, content_encoding); + } + + fn contentLengthStream( + io_r: *std.Io.Reader, + w: *Writer, + limit: std.Io.Limit, + ) std.Io.Reader.StreamError!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); + const remaining_content_length = &reader.state.body_remaining_content_length; + const remaining = remaining_content_length.*; + if (remaining == 0) { + reader.state = .ready; + return error.EndOfStream; + } + const n = try reader.in.stream(w, limit.min(.limited(remaining))); + remaining_content_length.* = remaining - n; + return n; + } + + fn contentLengthDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); + const remaining_content_length = &reader.state.body_remaining_content_length; + const remaining = remaining_content_length.*; + if (remaining == 0) { + reader.state = .ready; + return error.EndOfStream; + } + const n = try reader.in.discard(limit.min(.limited(remaining))); + remaining_content_length.* = remaining - n; + return n; + } + + fn chunkedStream(io_r: *std.Io.Reader, w: *Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); + const chunk_len_ptr = switch (reader.state) { + .ready => return error.EndOfStream, + .body_remaining_chunk_len => |*x| x, + else => unreachable, + }; + return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return error.WriteFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedReadEndless( + reader: *Reader, + w: *Writer, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.StreamError)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.stream(w, limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); + chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); + return n; + }, + } + } + + fn chunkedDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); + const chunk_len_ptr = switch (reader.state) { + .ready => return error.EndOfStream, + .body_remaining_chunk_len => |*x| x, + else => unreachable, + }; + return chunkedDiscardEndless(reader, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedDiscardEndless( + reader: *Reader, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.Error)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.discard(limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.discard(limit.min(.limited(remaining_chunk_len.int() - 2))); + chunk_len_ptr.* = .init(remaining_chunk_len.int() - n); + return n; + }, + } + } + + /// Called when next bytes in the stream are trailers, or "\r\n" to indicate + /// end of chunked body. + fn parseTrailers(reader: *Reader, amt_read: usize) (BodyError || std.Io.Reader.Error)!usize { + const in = reader.in; + const rn = try in.peekArray(2); + if (rn[0] == '\r' and rn[1] == '\n') { + in.toss(2); + reader.state = .ready; + assert(reader.trailers.len == 0); + return amt_read; + } + var hp: HeadParser = .{ .state = .seen_rn }; + var trailers_len: usize = 2; + while (true) { + if (in.buffer.len - trailers_len == 0) return error.HttpHeadersOversize; + const remaining = in.buffered()[trailers_len..]; + if (remaining.len == 0) { + try in.fillMore(); + continue; + } + trailers_len += hp.feed(remaining); + if (hp.state == .finished) { + reader.state = .ready; + reader.trailers = in.buffered()[0..trailers_len]; + in.toss(trailers_len); + return amt_read; + } + } + } +}; + +pub const Decompressor = union(enum) { + flate: std.compress.flate.Decompress, + zstd: std.compress.zstd.Decompress, + none: *std.Io.Reader, + + pub fn init( + decompressor: *Decompressor, + transfer_reader: *std.Io.Reader, + buffer: []u8, + content_encoding: ContentEncoding, + ) *std.Io.Reader { + switch (content_encoding) { + .identity => { + decompressor.* = .{ .none = transfer_reader }; + return transfer_reader; + }, + .deflate => { + decompressor.* = .{ .flate = .init(transfer_reader, .raw, buffer) }; + return &decompressor.flate.reader; + }, + .gzip => { + decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) }; + return &decompressor.flate.reader; + }, + .zstd => { + decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) }; + return &decompressor.zstd.reader; + }, + .compress => unreachable, + } + } +}; + +/// Request or response body. +pub const BodyWriter = struct { + /// Until the lifetime of `BodyWriter` ends, it is illegal to modify the + /// state of this other than via methods of `BodyWriter`. + http_protocol_output: *Writer, + state: State, + writer: Writer, + + pub const Error = Writer.Error; + + /// How many zeroes to reserve for hex-encoded chunk length. + const chunk_len_digits = 8; + const max_chunk_len: usize = std.math.pow(usize, 16, chunk_len_digits) - 1; + const chunk_header_template = ("0" ** chunk_len_digits) ++ "\r\n"; + + comptime { + assert(max_chunk_len == std.math.maxInt(u32)); + } + + pub const State = union(enum) { + /// End of connection signals the end of the stream. + none, + /// As a debugging utility, counts down to zero as bytes are written. + content_length: u64, + /// Each chunk is wrapped in a header and trailer. + chunked: Chunked, + /// Cleanly finished stream; connection can be reused. + end, + + pub const Chunked = union(enum) { + /// Index to the start of the hex-encoded chunk length in the chunk + /// header within the buffer of `BodyWriter.http_protocol_output`. + /// Buffered chunk data starts here plus length of `chunk_header_template`. + offset: usize, + /// We are in the middle of a chunk and this is how many bytes are + /// left until the next header. This includes +2 for "\r"\n", and + /// is zero for the beginning of the stream. + chunk_len: usize, + + pub const init: Chunked = .{ .chunk_len = 0 }; + }; + }; + + pub fn isEliding(w: *const BodyWriter) bool { + return w.writer.vtable.drain == Writer.discardingDrain; + } + + /// Sends all buffered data across `BodyWriter.http_protocol_output`. + pub fn flush(w: *BodyWriter) Error!void { + const out = w.http_protocol_output; + switch (w.state) { + .end, .none, .content_length => return out.flush(), + .chunked => |*chunked| switch (chunked.*) { + .offset => |offset| { + const chunk_len = out.end - offset - chunk_header_template.len; + if (chunk_len > 0) { + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + chunked.* = .{ .chunk_len = 2 }; + } else { + out.end = offset; + chunked.* = .{ .chunk_len = 0 }; + } + try out.flush(); + }, + .chunk_len => return out.flush(), + }, + } + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then flushes. + /// + /// When using transfer-encoding: chunked, writes the end-of-stream message + /// with empty trailers, then flushes the stream to the system. Asserts any + /// started chunk has been completely finished. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endUnflushed` + /// * `endChunked` + pub fn end(w: *BodyWriter) Error!void { + try endUnflushed(w); + try w.http_protocol_output.flush(); + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header. + /// + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message with empty trailers. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `end` + /// * `endChunked` + pub fn endUnflushed(w: *BodyWriter) Error!void { + switch (w.state) { + .end => unreachable, + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + w.state = .end; + }, + .none => {}, + .chunked => return endChunkedUnflushed(w, .{}), + } + } + + pub const EndChunkedOptions = struct { + trailers: []const Header = &.{}, + }; + + /// Writes the end-of-stream message and any optional trailers, flushing + /// the underlying stream. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunkedUnflushed` + /// * `end` + pub fn endChunked(w: *BodyWriter, options: EndChunkedOptions) Error!void { + try endChunkedUnflushed(w, options); + try w.http_protocol_output.flush(); + } + + /// Writes the end-of-stream message and any optional trailers. + /// + /// Does not flush. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunked` + /// * `endUnflushed` + /// * `end` + pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void { + const chunked = &w.state.chunked; + if (w.isEliding()) { + w.state = .end; + return; + } + const bw = w.http_protocol_output; + switch (chunked.*) { + .offset => |offset| { + const chunk_len = bw.end - offset - chunk_header_template.len; + writeHex(bw.buffer[offset..][0..chunk_len_digits], chunk_len); + try bw.writeAll("\r\n"); + }, + .chunk_len => |chunk_len| switch (chunk_len) { + 0 => {}, + 1 => try bw.writeByte('\n'), + 2 => try bw.writeAll("\r\n"), + else => unreachable, // An earlier write call indicated more data would follow. + }, + } + try bw.writeAll("0\r\n"); + for (options.trailers) |trailer| { + try bw.writeAll(trailer.name); + try bw.writeAll(": "); + try bw.writeAll(trailer.value); + try bw.writeAll("\r\n"); + } + try bw.writeAll("\r\n"); + w.state = .end; + } + + pub fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + return w.consume(n); + } + + /// Returns `null` if size cannot be computed without making any syscalls. + pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + return w.consume(n); + } + + pub fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const data_len = Writer.countSendFileLowerBound(w.end, file_reader, limit) orelse { + // If the file size is unknown, we cannot lower to a `sendFile` since we would + // have to flush the chunk header before knowing the chunk length. + return error.Unimplemented; + }; + const out = bw.http_protocol_output; + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |off| { + // TODO: is it better perf to read small files into the buffer? + const buffered_len = out.end - off - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[off..][0..chunk_len_digits], chunk_len); + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const off = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = off }; + continue :state .{ .offset = off }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const new_limit = limit.min(.limited(chunk_len - 2)); + const n = try out.sendFileHeader(w.buffered(), file_reader, new_limit); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + pub fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const data_len = w.end + Writer.countSplat(data, splat); + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |offset| { + if (out.unusedCapacityLen() >= data_len) { + return w.consume(out.writeSplatHeader(w.buffered(), data, splat) catch unreachable); + } + const buffered_len = out.end - offset - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + const n = try out.writeSplatHeader(w.buffered(), data, splat); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const offset = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = offset }; + continue :state .{ .offset = offset }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const n = try out.writeSplatHeaderLimit(w.buffered(), data, splat, .limited(chunk_len - 2)); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + /// Writes an integer as base 16 to `buf`, right-aligned, assuming the + /// buffer has already been filled with zeroes. + fn writeHex(buf: []u8, x: usize) void { + assert(std.mem.allEqual(u8, buf, '0')); + const base = 16; + var index: usize = buf.len; + var a = x; + while (a > 0) { + const digit = a % base; + index -= 1; + buf[index] = std.fmt.digitToChar(@intCast(digit), .lower); + a /= base; + } + } +}; + test { + _ = Server; + _ = Status; + _ = Method; + _ = ChunkParser; + _ = HeadParser; + if (builtin.os.tag != .wasi) { _ = Client; - _ = Method; - _ = Server; - _ = Status; - _ = HeadParser; - _ = ChunkParser; - _ = WebSocket; _ = @import("http/test.zig"); } } diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig index adcdc74bc7..7c628ec327 100644 --- a/lib/std/http/ChunkParser.zig +++ b/lib/std/http/ChunkParser.zig @@ -1,5 +1,8 @@ //! Parser for transfer-encoding: chunked. +const ChunkParser = @This(); +const std = @import("std"); + state: State, chunk_len: u64, @@ -97,9 +100,6 @@ pub fn feed(p: *ChunkParser, bytes: []const u8) usize { return bytes.len; } -const ChunkParser = @This(); -const std = @import("std"); - test feed { const testing = std.testing; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 20f6018e45..61a9eeb5c3 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,9 +13,10 @@ const net = std.net; const Uri = std.Uri; const Allocator = mem.Allocator; const assert = std.debug.assert; +const Writer = std.io.Writer; +const Reader = std.io.Reader; const Client = @This(); -const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; @@ -24,6 +25,12 @@ allocator: Allocator, ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, +/// Used both for the reader and writer buffers. +tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +/// If non-null, ssl secrets are logged to a stream. Creating such a stream +/// allows other processes with access to that stream to decrypt all +/// traffic over connections created with this `Client`. +ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -31,6 +38,13 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, +/// Each `Connection` allocates this amount for the reader buffer. +/// +/// If the entire HTTP header cannot fit in this amount of bytes, +/// `error.HttpHeadersOversize` will be returned from `Request.wait`. +read_buffer_size: usize = 4096, +/// Each `Connection` allocates this amount for the writer buffer. +write_buffer_size: usize = 1024, /// If populated, all http traffic travels through this third party. /// This field cannot be modified while the client has active connections. @@ -41,7 +55,7 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, -/// A set of linked lists of connections that can be reused. +/// A Least-Recently-Used cache of open connections to be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, /// Open connections that are currently in use. @@ -55,11 +69,13 @@ pub const ConnectionPool = struct { pub const Criteria = struct { host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, }; - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// Finds and acquires a connection from the connection pool matching the criteria. /// If no connection is found, null is returned. + /// + /// Threadsafe. pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -71,7 +87,7 @@ pub const ConnectionPool = struct { if (connection.port != criteria.port) continue; // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(connection.host, criteria.host)) continue; + if (!std.ascii.eqlIgnoreCase(connection.host(), criteria.host)) continue; pool.acquireUnsafe(connection); return connection; @@ -96,28 +112,25 @@ pub const ConnectionPool = struct { return pool.acquireUnsafe(connection); } - /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// Tries to release a connection back to the connection pool. /// If the connection is marked as closing, it will be closed instead. /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + /// `allocator` must be the same one used to create `connection`. + /// + /// Threadsafe. + pub fn release(pool: *ConnectionPool, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(&connection.pool_node); - if (connection.closing or pool.free_size == 0) { - connection.close(allocator); - return allocator.destroy(connection); - } + if (connection.closing or pool.free_size == 0) return connection.destroy(); if (pool.free_len >= pool.free_size) { const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); pool.free_len -= 1; - popped.close(allocator); - allocator.destroy(popped); + popped.destroy(); } if (connection.proxied) { @@ -138,9 +151,11 @@ pub const ConnectionPool = struct { pool.used.append(&connection.pool_node); } - /// Resizes the connection pool. This function is threadsafe. + /// Resizes the connection pool. /// /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + /// + /// Threadsafe. pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -158,538 +173,586 @@ pub const ConnectionPool = struct { pool.free_size = new_size; } - /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// Frees the connection pool and closes all connections within. /// /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + /// + /// Threadsafe. + pub fn deinit(pool: *ConnectionPool) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { const connection: *Connection = @fieldParentPtr("pool_node", node); next = node.next; - connection.close(allocator); - allocator.destroy(connection); + connection.destroy(); } next = pool.used.first; while (next) |node| { const connection: *Connection = @fieldParentPtr("pool_node", node); next = node.next; - connection.close(allocator); - allocator.destroy(node); + connection.destroy(); } pool.* = undefined; } }; -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, +pub const Protocol = enum { + plain, + tls, + fn port(protocol: Protocol) u16 { + return switch (protocol) { + .plain => 80, + .tls => 443, + }; + } + + pub fn fromScheme(scheme: []const u8) ?Protocol { + const protocol_map = std.StaticStringMap(Protocol).initComptime(.{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, + }); + return protocol_map.get(scheme); + } + + pub fn fromUri(uri: Uri) ?Protocol { + return fromScheme(uri.scheme); + } +}; + +pub const Connection = struct { + client: *Client, + stream_writer: net.Stream.Writer, + stream_reader: net.Stream.Reader, /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. pool_node: std.DoublyLinkedList.Node, - - /// The protocol that this connection is using. + port: u16, + host_len: u8, + proxied: bool, + closing: bool, protocol: Protocol, - /// The host that this connection is connected to. - host: []u8, + const Plain = struct { + connection: Connection, - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{OutOfMemory}!*Plain { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len]; + const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size]; + const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const plain: *Plain = @ptrCast(base); + plain.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(socket_read_buffer), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = false, + .protocol = .plain, + }, + }; + return plain; } - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; + fn destroy(plain: *Plain) void { + const c = &plain.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size; } - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, + fn host(plain: *Plain) []u8 { + const base: [*]u8 = @ptrCast(plain); + return base[@sizeOf(Plain)..][0..plain.connection.host_len]; + } }; - pub const Reader = std.io.GenericReader(*Connection, ReadError, read); + const Tls = struct { + client: std.crypto.tls.Client, + connection: Connection, - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{ OutOfMemory, TlsInitializationFailed }!*Tls { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; + const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; + const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; + const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const tls: *Tls = @ptrCast(base); + tls.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(&.{}), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = false, + .protocol = .tls, + }, + // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true + .client = std.crypto.tls.Client.init( + tls.connection.stream_reader.interface(), + &tls.connection.stream_writer.interface, + .{ + .host = .{ .explicit = remote_host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log = client.ssl_key_log, + .read_buffer = tls_read_buffer, + .write_buffer = tls_write_buffer, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + .allow_truncation_attacks = true, + }, + ) catch return error.TlsInitializationFailed, + }; + return tls; } - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } + fn destroy(tls: *Tls) void { + const c = &tls.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size; + } - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, + fn host(tls: *Tls) []u8 { + const base: [*]u8 = @ptrCast(tls); + return base[@sizeOf(Tls)..][0..tls.connection.host_len]; + } }; - pub const Writer = std.io.GenericWriter(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; + fn getStream(c: *Connection) net.Stream { + return c.stream_reader.getStream(); } - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; + fn host(c: *Connection) []u8 { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return tls.host(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + return plain.host(); + }, + }; + } - // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); - allocator.destroy(conn.tls_client); + /// If this is called without calling `flush` or `end`, data will be + /// dropped unsent. + pub fn destroy(c: *Connection) void { + c.getStream().close(); + switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + tls.destroy(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + plain.destroy(); + }, } + } - conn.stream.close(); - allocator.free(conn.host); + /// HTTP protocol from client to server. + /// This either goes directly to `stream_writer`, or to a TLS client. + pub fn writer(c: *Connection) *Writer { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.writer; + }, + .plain => &c.stream_writer.interface, + }; + } + + /// HTTP protocol from server to client. + /// This either comes directly from `stream_reader`, or from a TLS client. + pub fn reader(c: *Connection) *Reader { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.reader; + }, + .plain => c.stream_reader.interface(), + }; + } + + pub fn flush(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.client.writer.flush(); + } + try c.stream_writer.interface.flush(); + } + + /// If the connection is a TLS connection, sends the close_notify alert. + /// + /// Flushes all buffers. + pub fn end(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.client.end(); + try tls.client.writer.flush(); + } + try c.stream_writer.interface.flush(); } }; -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. -pub const Compression = union(enum) { - //deflate: std.compress.flate.Decompress, - //gzip: std.compress.flate.Decompress, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, + request: *Request, + /// Pointers in this struct are invalidated with the next call to + /// `receiveHead`. + head: Head, - /// Points into the user-provided `server_header_buffer`. - location: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_type: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_disposition: ?[]const u8 = null, + pub const Head = struct { + bytes: []const u8, + version: http.Version, + status: http.Status, + reason: []const u8, + location: ?[]const u8 = null, + content_type: ?[]const u8 = null, + content_disposition: ?[]const u8 = null, - keep_alive: bool, + keep_alive: bool, - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, + transfer_encoding: http.TransferEncoding = .none, + content_encoding: http.ContentEncoding = .identity, - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the - /// response body will be discarded. - skip: bool = false, - - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; - - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimStart(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, + pub const ParseError = error{ + HttpConnectionHeaderUnsupported, + HttpContentEncodingUnsupported, + HttpHeaderContinuationsUnsupported, + HttpHeadersInvalid, + HttpTransferEncodingUnsupported, + InvalidContentLength, }; - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, + pub fn parse(bytes: []const u8) ParseError!Head { + var res: Head = .{ + .bytes = bytes, + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + }; + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 12) { + return error.HttpHeadersInvalid; } - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); + const reason = mem.trimLeft(u8, first_line[12..], " "); - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); + res.version = version; + res.status = status; + res.reason = reason; + res.keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }; - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); + while (it.next()) |line| { + if (line.len == 0) return res; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, } - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.content_encoding = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (http.ContentEncoding.fromString(trimmed)) |ce| { + res.content_encoding = ce; } else { - return error.HttpTransferEncodingUnsupported; + return error.HttpContentEncodingUnsupported; } } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } } + return error.HttpHeadersInvalid; // missing empty line } - return error.HttpHeadersInvalid; // missing empty line + + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", head.version); + try testing.expectEqualStrings("OK", head.reason); + try testing.expectEqual(.ok, head.status); + + try testing.expectEqualStrings("url", head.location.?); + try testing.expectEqualStrings("text/plain", head.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", head.content_disposition.?); + + try testing.expectEqual(true, head.keep_alive); + try testing.expectEqual(10, head.content_length.?); + try testing.expectEqual(.chunked, head.transfer_encoding); + try testing.expectEqual(.deflate, head.content_encoding); + } + + pub fn iterateHeaders(h: Head) http.HeaderIterator { + return .init(h.bytes); + } + + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + var it = head.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + + fn parseInt3(text: *const [3]u8) u10 { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, (nnn -% zero) *% mmm); + } + + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } + }; + + /// If compressed body has been negotiated this will return compressed bytes. + /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. + /// + /// See also: + /// * `readerDecompressing` + pub fn reader(response: *Response, buffer: []u8) *Reader { + const req = response.request; + if (!req.method.responseHasBody()) return .ending; + const head = &response.head; + return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length); } - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; + /// If compressed body has been negotiated this will return decompressed bytes. + /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. + /// + /// See also: + /// * `reader` + pub fn readerDecompressing( + response: *Response, + decompressor: *http.Decompressor, + decompression_buffer: []u8, + ) *Reader { + const head = &response.head; + return response.request.reader.bodyReaderDecompressing( + head.transfer_encoding, + head.content_length, + head.content_encoding, + decompressor, + decompression_buffer, + ); + } - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), + /// After receiving `error.ReadFailed` from the `Reader` returned by + /// `reader` or `readerDecompressing`, this function accesses the + /// more specific error code. + pub fn bodyErr(response: *const Response) ?http.Reader.BodyError { + return response.request.reader.body_err; + } + + pub fn iterateTrailers(response: *const Response) http.HeaderIterator { + const r = &response.request.reader; + assert(r.state == .ready); + return .{ + .bytes = r.trailers, + .index = 0, + .is_trailer = true, }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - try res.parse(response_bytes); - - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); - - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); - - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - - fn parseInt3(text: *const [3]u8) u10 { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, (nnn -% zero) *% mmm); - } - - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); - } - - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return .init(r.parser.get()); - } - - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); } }; -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read pub const Request = struct { + /// This field is provided so that clients can observe redirected URIs. + /// + /// Its backing memory is externally provided by API users when creating a + /// request, and then again provided externally via `redirect_buffer` to + /// `receiveHead`. uri: Uri, client: *Client, /// This is null when the connection is released. connection: ?*Connection, + reader: http.Reader, keep_alive: bool, method: http.Method, version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer, + transfer_encoding: TransferEncoding, redirect_behavior: RedirectBehavior, + accept_encoding: @TypeOf(default_accept_encoding) = default_accept_encoding, /// Whether the request should handle a 100-continue response before sending the request body. handle_continue: bool, - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response, - /// Standard headers that have default, but overridable, behavior. headers: Headers, @@ -703,6 +766,20 @@ pub const Request = struct { /// Externally-owned; must outlive the Request. privileged_headers: []const http.Header, + pub const default_accept_encoding: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = b: { + var result: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = @splat(false); + result[@intFromEnum(http.ContentEncoding.gzip)] = true; + result[@intFromEnum(http.ContentEncoding.deflate)] = true; + result[@intFromEnum(http.ContentEncoding.identity)] = true; + break :b result; + }; + + pub const TransferEncoding = union(enum) { + content_length: u64, + chunked: void, + none: void, + }; + pub const Headers = struct { host: Value = .default, authorization: Value = .default, @@ -742,98 +819,102 @@ pub const Request = struct { } }; - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); + /// Returns the request's `Connection` back to the pool of the `Client`. + pub fn deinit(r: *Request) void { + r.reader.restituteHeadBuffer(); + if (r.connection) |connection| { + connection.closing = connection.closing or switch (r.reader.state) { + .ready => false, + .received_head => r.method.requestHasBody(), + else => true, + }; + r.client.connection_pool.release(connection); } - req.* = undefined; + r.* = undefined; } - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. - fn redirect(req: *Request, uri: Uri) !void { - assert(req.response.parser.done); + /// Sends and flushes a complete request as only HTTP head, no body. + pub fn sendBodiless(r: *Request) Writer.Error!void { + try sendBodilessUnflushed(r); + try r.connection.?.flush(); + } - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; + /// Sends but does not flush a complete request as only HTTP head, no body. + pub fn sendBodilessUnflushed(r: *Request) Writer.Error!void { + assert(r.transfer_encoding == .none); + assert(!r.method.requestHasBody()); + try sendHead(r); + } - var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + /// Transfers the HTTP head over the connection and flushes. + /// + /// See also: + /// * `sendBodyUnflushed` + pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + const result = try sendBodyUnflushed(r, buffer); + try r.connection.?.flush(); + return result; + } - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } - - req.uri = valid_uri; - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, + /// Transfers the HTTP head over the connection, which is not flushed until + /// `BodyWriter.flush` or `BodyWriter.end` is called. + /// + /// See also: + /// * `sendBody` + pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + assert(r.method.requestHasBody()); + try sendHead(r); + const http_protocol_output = r.connection.?.writer(); + return switch (r.transfer_encoding) { + .chunked => .{ + .http_protocol_output = http_protocol_output, + .state = .{ .chunked = .init }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + }, + }, + .content_length => |len| .{ + .http_protocol_output = http_protocol_output, + .state = .{ .content_length = len }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + }, + }, + .none => .{ + .http_protocol_output = http_protocol_output, + .state = .none, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + }, + }, }; } - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + /// Sends HTTP headers without flushing. + fn sendHead(r: *Request) Writer.Error!void { + const uri = r.uri; + const connection = r.connection.?; + const w = connection.writer(); - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; - - const connection = req.connection.?; - var connection_writer_adapter = connection.writer().adaptToNewApi(); - const w = &connection_writer_adapter.new_interface; - sendAdapted(req, connection, w) catch |err| switch (err) { - error.WriteFailed => return connection_writer_adapter.err.?, - else => |e| return e, - }; - } - - fn sendAdapted(req: *Request, connection: *Connection, w: *std.io.Writer) !void { - try req.method.format(w); + try r.method.write(w); try w.writeByte(' '); - if (req.method == .CONNECT) { - try req.uri.writeToStream(w, .{ .authority = true }); + if (r.method == .CONNECT) { + try uri.writeToStream(.{ .authority = true }, w); } else { - try req.uri.writeToStream(w, .{ + try uri.writeToStream(.{ .scheme = connection.proxied, .authentication = connection.proxied, .authority = connection.proxied, @@ -842,58 +923,64 @@ pub const Request = struct { }); } try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); + try w.writeAll(@tagName(r.version)); try w.writeAll("\r\n"); - if (try emitOverridableHeader("host: ", req.headers.host, w)) { + if (try emitOverridableHeader("host: ", r.headers.host, w)) { try w.writeAll("host: "); - try req.uri.writeToStream(w, .{ .authority = true }); + try uri.writeToStream(.{ .authority = true }, w); try w.writeAll("\r\n"); } - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { + if (try emitOverridableHeader("authorization: ", r.headers.authorization, w)) { + if (uri.user != null or uri.password != null) { try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try basic_authorization.write(uri, w); try w.writeAll("\r\n"); } } - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + if (try emitOverridableHeader("user-agent: ", r.headers.user_agent, w)) { try w.writeAll("user-agent: zig/"); try w.writeAll(builtin.zig_version_string); try w.writeAll(" (std.http)\r\n"); } - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { + if (try emitOverridableHeader("connection: ", r.headers.connection, w)) { + if (r.keep_alive) { try w.writeAll("connection: keep-alive\r\n"); } else { try w.writeAll("connection: close\r\n"); } } - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); + if (try emitOverridableHeader("accept-encoding: ", r.headers.accept_encoding, w)) { + try w.writeAll("accept-encoding: "); + for (r.accept_encoding, 0..) |enabled, i| { + if (!enabled) continue; + const tag: http.ContentEncoding = @enumFromInt(i); + if (tag == .identity) continue; + const tag_name = @tagName(tag); + try w.ensureUnusedCapacity(tag_name.len + 2); + try w.writeAll(tag_name); + try w.writeAll(", "); + } + w.undo(2); + try w.writeAll("\r\n"); } - switch (req.transfer_encoding) { + switch (r.transfer_encoding) { .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), .none => {}, } - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + if (try emitOverridableHeader("content-type: ", r.headers.content_type, w)) { // The default is to omit content-type if not provided because // "application/octet-stream" is redundant. } - for (req.extra_headers) |header| { + for (r.extra_headers) |header| { assert(header.name.len != 0); try w.writeAll(header.name); @@ -904,8 +991,8 @@ pub const Request = struct { if (connection.proxied) proxy: { const proxy = switch (connection.protocol) { - .plain => req.client.http_proxy, - .tls => req.client.https_proxy, + .plain => r.client.http_proxy, + .tls => r.client.https_proxy, } orelse break :proxy; const authorization = proxy.authorization orelse break :proxy; @@ -915,282 +1002,198 @@ pub const Request = struct { } try w.writeAll("\r\n"); - - try connection.flush(); } - /// Returns true if the default behavior is required, otherwise handles - /// writing (or not writing) the header. - fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { - switch (v) { - .default => return true, - .omit => return false, - .override => |x| { - try w.writeAll(prefix); - try w.writeAll(x); - try w.writeAll("\r\n"); - return false; - }, - } - } + pub const ReceiveHeadError = http.Reader.HeadError || ConnectError || error{ + /// Server sent headers that did not conform to the HTTP protocol. + /// + /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be + /// passed directly to `Request.Head.parse`. + HttpHeadersInvalid, + TooManyHttpRedirects, + /// This can be avoided by calling `receiveHead` before sending the + /// request body. + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationOversize, + HttpRedirectLocationInvalid, + HttpContentEncodingUnsupported, + HttpChunkInvalid, + HttpChunkTruncated, + HttpHeadersOversize, + UnsupportedUriScheme, - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + /// Sending the request failed. Error code can be found on the + /// `Connection` object. + WriteFailed, + }; - const TransferReader = std.io.GenericReader(*Request, TransferReadError, transferRead); - - fn transferReader(req: *Request) TransferReader { - return .{ .context = req }; - } - - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { - if (req.response.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); - if (amt == 0 and req.response.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = RequestError || SendError || TransferReadError || - proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || - error{ - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; - - /// Waits for a response from the server and parses any headers that are sent. - /// This function will block until the final response is received. - /// /// If handling redirects and the request has no payload, then this - /// function will automatically follow redirects. If a request payload is - /// present, then this function will error with - /// error.RedirectRequiresResend. + /// function will automatically follow redirects. /// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { + /// If a request payload is present, then this function will error with + /// `error.RedirectRequiresResend`. + /// + /// This function takes an auxiliary buffer to store the arbitrarily large + /// URI which may need to be merged with the previous URI, and that data + /// needs to survive across different connections, which is where the input + /// buffer lives. + /// + /// `redirect_buffer` must outlive accesses to `Request.uri`. If this + /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize` + /// is returned instead. This buffer may be empty if no redirects are to be + /// handled. + pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { + var aux_buf = redirect_buffer; while (true) { + try r.reader.receiveHead(); + const response: Response = .{ + .request = r, + .head = Response.Head.parse(r.reader.head_buffer) catch return error.HttpHeadersInvalid, + }; + const head = &response.head; + + if (head.status == .@"continue") { + if (r.handle_continue) continue; + return response; // we're not handling the 100-continue + } + // This while loop is for handling redirects, which means the request's // connection may be different than the previous iteration. However, it // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; + const connection = r.connection.?; - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) break; - } - - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { + if (r.method == .CONNECT and head.status.class() == .success) { + // This connection is no longer doing HTTP. connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point + return response; } - connection.closing = !req.response.keep_alive or !req.keep_alive; + connection.closing = !head.keep_alive or !r.keep_alive; // Any response to a HEAD request and any response with a 1xx // (Informational), 204 (No Content), or 304 (Not Modified) status // code is always terminated by the first empty line after the // header fields, regardless of the header fields present in the // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) + if (r.method == .HEAD or head.status.class() == .informational or + head.status == .no_content or head.status == .not_modified) { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. + return response; } - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .deflate => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .gzip, .@"x-gzip" => return error.CompressionUnsupported, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } + if (head.status.class() == .redirect and r.redirect_behavior != .unhandled) { + if (r.redirect_behavior == .not_allowed) { + // Connection can still be reused by skipping the body. + const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => connection.closing = true, + }; + return error.TooManyHttpRedirects; } - - break; + try r.redirect(head, &aux_buf); + try r.sendBodiless(); + continue; } + + if (!r.accept_encoding[@intFromEnum(head.content_encoding)]) + return error.HttpContentEncodingUnsupported; + + return response; } } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.GenericReader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - // I'm about to upstream my http client rewrite - //.deflate => |*deflate| deflate.readSlice(buffer) catch return error.DecompressionFailure, - //.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), + /// This function takes an auxiliary buffer to store the arbitrarily large + /// URI which may need to be merged with the previous URI, and that data + /// needs to survive across different connections, which is where the input + /// buffer lives. + /// + /// `aux_buf` must outlive accesses to `Request.uri`. + fn redirect(r: *Request, head: *const Response.Head, aux_buf: *[]u8) !void { + const new_location = head.location orelse return error.HttpRedirectLocationMissing; + if (new_location.len > aux_buf.*.len) return error.HttpRedirectLocationOversize; + const location = aux_buf.*[0..new_location.len]; + @memcpy(location, new_location); + { + // Skip the body of the redirect response to leave the connection in + // the correct state. This causes `new_location` to be invalidated. + const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return r.reader.body_err.?, + }; + r.reader.restituteHeadBuffer(); + } + const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) { + error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid, + error.InvalidFormat => return error.HttpRedirectLocationInvalid, + error.InvalidPort => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpRedirectLocationOversize, }; - if (out_index > 0) return out_index; - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); + const protocol = Protocol.fromUri(new_uri) orelse return error.UnsupportedUriScheme; + const old_connection = r.connection.?; + const old_host = old_connection.host(); + var new_host_name_buffer: [Uri.host_name_max]u8 = undefined; + const new_host = try new_uri.getHost(&new_host_name_buffer); + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(r.uri.scheme, new_uri.scheme) and + sameParentDomain(old_host, new_host); - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); + r.client.connection_pool.release(old_connection); + r.connection = null; + + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + r.privileged_headers = &.{}; } - return 0; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; + if (switch (head.status) { + .see_other => true, + .moved_permanently, .found => r.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. + r.method = .GET; + r.transfer_encoding = .none; + r.headers.content_type = .omit; } - return index; + + if (r.transfer_encoding != .none) { + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. + return error.RedirectRequiresResend; + } + + const new_connection = try r.client.connect(new_host, uriPort(new_uri, protocol), protocol); + r.uri = new_uri; + r.connection = new_connection; + r.reader = .{ + .in = new_connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }; + r.redirect_behavior.subtractOne(); } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.GenericWriter(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. + fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, bw: *Writer) Writer.Error!bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + var vecs: [3][]const u8 = .{ prefix, x, "\r\n" }; + try bw.writeVecAll(&vecs); + return false; }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, } } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - pub const FinishError = WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - - try req.connection.?.flush(); - } }; pub const Proxy = struct { - protocol: Connection.Protocol, + protocol: Protocol, host: []const u8, authorization: ?[]const u8, port: u16, @@ -1204,10 +1207,8 @@ pub const Proxy = struct { pub fn deinit(client: *Client) void { assert(client.connection_pool.used.first == null); // There are still active requests. - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); + client.connection_pool.deinit(); + if (!disable_tls) client.ca_bundle.deinit(client.allocator); client.* = undefined; } @@ -1249,24 +1250,21 @@ fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !? } else return null; const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; + const protocol = Protocol.fromUri(uri) orelse return null; + const raw_host = try uri.getHostAlloc(arena); - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); + const authorization: ?[]const u8 = if (uri.user != null or uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri)); + assert(basic_authorization.value(uri, authorization).len == authorization.len); break :a authorization; } else null; const proxy = try arena.create(Proxy); proxy.* = .{ .protocol = protocol, - .host = valid_uri.host.?.raw, + .host = raw_host, .authorization = authorization, - .port = uriPort(valid_uri, protocol), + .port = uriPort(uri, protocol), .supports_connect = true, }; return proxy; @@ -1277,10 +1275,8 @@ pub const basic_authorization = struct { pub const max_password_len = 255; pub const max_value_len = valueLength(max_user_len, max_password_len); - const prefix = "Basic "; - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + return "Basic ".len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); } pub fn valueLengthFromUri(uri: Uri) usize { @@ -1300,37 +1296,69 @@ pub const basic_authorization = struct { } pub fn value(uri: Uri, out: []u8) []u8 { - const user: Uri.Component = uri.user orelse .empty; - const password: Uri.Component = uri.password orelse .empty; + var bw: Writer = .fixed(out); + write(uri, &bw) catch unreachable; + return bw.getWritten(); + } + pub fn write(uri: Uri, out: *Writer) Writer.Error!void { var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var w: std.io.Writer = .fixed(&buf); - user.formatUser(&w) catch unreachable; // fixed - password.formatPassword(&w) catch unreachable; // fixed - - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], w.buffered()); - return out[0 .. prefix.len + base64.len]; + var w: Writer = .fixed(&buf); + w.print("{fuser}:{fpassword}", .{ + uri.user orelse Uri.Component.empty, + uri.password orelse Uri.Component.empty, + }) catch unreachable; + try out.print("Basic {b64}", .{w.buffered()}); } }; -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +pub const ConnectTcpError = Allocator.Error || error{ + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, + TlsInitializationFailed, +}; -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// Reuses a `Connection` if one matching `host` and `port` is already open. /// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { +/// Threadsafe. +pub fn connectTcp( + client: *Client, + host: []const u8, + port: u16, + protocol: Protocol, +) ConnectTcpError!*Connection { + return connectTcpOptions(client, .{ .host = host, .port = port, .protocol = protocol }); +} + +pub const ConnectTcpOptions = struct { + host: []const u8, + port: u16, + protocol: Protocol, + + proxied_host: ?[]const u8 = null, + proxied_port: ?u16 = null, +}; + +pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcpError!*Connection { + const host = options.host; + const port = options.port; + const protocol = options.protocol; + + const proxied_host = options.proxied_host orelse host; + const proxied_port = options.proxied_port orelse port; + if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, + .host = proxied_host, + .port = proxied_port, .protocol = protocol, - })) |node| return node; - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(Connection); - errdefer client.allocator.destroy(conn); + })) |conn| return conn; const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, @@ -1345,53 +1373,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer stream.close(); - conn.* = .{ - .stream = stream, - .tls_client = undefined, - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - - .pool_node = .{}, - }; - errdefer client.allocator.free(conn.host); - - if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.tls_client); - - const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { - const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { - error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, - error.OutOfMemory => return error.OutOfMemory, - }; - defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ - .truncate = false, - .mode = switch (builtin.os.tag) { - .windows, .wasi => 0, - else => 0o600, - }, - }) catch null; - } else null; - errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); - - conn.tls_client.* = std.crypto.tls.Client.init(stream, .{ - .host = .{ .explicit = host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log_file = ssl_key_log_file, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + switch (protocol) { + .tls => { + if (disable_tls) return error.TlsInitializationFailed; + const tc = try Connection.Tls.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&tc.connection); + return &tc.connection; + }, + .plain => { + const pc = try Connection.Plain.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&pc.connection); + return &pc.connection; + }, } - - client.connection_pool.addUsed(conn); - - return conn; } pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; @@ -1429,69 +1423,67 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return &conn.data; } -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// Connect to `proxied_host:proxied_port` using the specified proxy with HTTP /// CONNECT. This will reuse a connection if one is already open. /// /// This function is threadsafe. -pub fn connectTunnel( +pub fn connectProxied( client: *Client, proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, + proxied_host: []const u8, + proxied_port: u16, ) !*Connection { if (!proxy.supports_connect) return error.TunnelNotSupported; if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, + .host = proxied_host, + .port = proxied_port, .protocol = proxy.protocol, - })) |node| - return node; + })) |node| return node; var maybe_valid = false; (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + const connection = try client.connectTcpOptions(.{ + .host = proxy.host, + .port = proxy.port, + .protocol = proxy.protocol, + .proxied_host = proxied_host, + .proxied_port = proxied_port, + }); errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); + connection.closing = true; + client.connection_pool.release(connection); } - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ + var req = client.request(.CONNECT, .{ .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, + .host = .{ .raw = proxied_host }, + .port = proxied_port, }, .{ .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, + .connection = connection, }) catch |err| { - std.log.debug("err {}", .{err}); break :tunnel err; }; defer req.deinit(); - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; + req.sendBodiless() catch |err| break :tunnel err; + const response = req.receiveHead(&.{}) catch |err| break :tunnel err; - if (req.response.status.class() == .server_error) { + if (response.head.status.class() == .server_error) { maybe_valid = true; break :tunnel error.ServerError; } - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + if (response.head.status != .ok) break :tunnel error.ConnectionRefused; - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + // this connection is now a tunnel, so we can't use it for anything + // else, it will only be released when the client is de-initialized. req.connection = null; - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); + connection.closing = false; - conn.port = tunnel_port; - conn.closing = false; - - return conn; + return connection; }) catch { // something went wrong with the tunnel proxy.supports_connect = maybe_valid; @@ -1499,12 +1491,11 @@ pub fn connectTunnel( }; } -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; +pub const ConnectError = ConnectTcpError || RequestError; /// Connect to `host:port` using the specified protocol. This will reuse a /// connection if one is already open. +/// /// If a proxy is configured for the client, then the proxy will be used to /// connect to the host. /// @@ -1513,7 +1504,7 @@ pub fn connect( client: *Client, host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, ) ConnectError!*Connection { const proxy = switch (protocol) { .plain => client.http_proxy, @@ -1528,32 +1519,24 @@ pub fn connect( } if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + return connectProxied(client, proxy, host, port) catch |err| switch (err) { error.TunnelNotSupported => break :tunnel, else => |e| return e, }; } // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; + const connection = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + connection.proxied = true; + return connection; } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, - }; +pub const RequestError = ConnectTcpError || error{ + UnsupportedUriScheme, + UriMissingHost, + UriHostTooLong, + CertificateBundleLoadFailure, +}; pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", @@ -1578,11 +1561,6 @@ pub const RequestOptions = struct { /// payload or the server has acknowledged the payload). redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - /// Must be an already acquired connection. connection: ?*Connection = null, @@ -1598,38 +1576,17 @@ pub const RequestOptions = struct { privileged_headers: []const http.Header = &.{}, }; -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, - }); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. - valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; +fn uriPort(uri: Uri, protocol: Protocol) u16 { + return uri.port orelse protocol.port(); } /// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// -/// `uri` must remain alive during the entire request. -/// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. /// /// Asserts that "\r\n" does not occur in any header name or value. -pub fn open( +pub fn request( client: *Client, method: http.Method, uri: Uri, @@ -1649,59 +1606,58 @@ pub fn open( } } - var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + const protocol = Protocol.fromUri(uri) orelse return error.UnsupportedUriScheme; - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (protocol == .tls) { if (disable_tls) unreachable; + if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } } } - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); + const connection = options.connection orelse c: { + var host_name_buffer: [Uri.host_name_max]u8 = undefined; + const host_name = try uri.getHost(&host_name_buffer); + break :c try client.connect(host_name, uriPort(uri, protocol), protocol); + }; - var req: Request = .{ - .uri = valid_uri, + return .{ + .uri = uri, .client = client, - .connection = conn, + .connection = connection, + .reader = .{ + .in = connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, .keep_alive = options.keep_alive, .method = method, .version = options.version, .transfer_encoding = .none, .redirect_behavior = options.redirect_behavior, .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = .init(server_header.buffer[server_header.end_index..]), - }, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, }; - errdefer req.deinit(); - - return req; } pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + redirect_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + decompress_buffer: ?[]u8 = null, redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. - response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, + /// If the server sends a body, it will be stored here. + response_storage: ?ResponseStorage = null, location: Location, method: ?http.Method = null, @@ -1725,11 +1681,11 @@ pub const FetchOptions = struct { uri: Uri, }; - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), + pub const ResponseStorage = struct { + list: *std.ArrayListUnmanaged(u8), + /// If null then only the existing capacity will be used. + allocator: ?Allocator = null, + append_limit: std.io.Limit = .unlimited, }; }; @@ -1737,23 +1693,28 @@ pub const FetchResult = struct { status: http.Status, }; +pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadError || error{ + StreamTooLong, + /// TODO provide optional diagnostics when this occurs or break into more error codes + WriteFailed, +}; + /// Perform a one-shot HTTP request with the provided options. /// /// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { +pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult { const uri = switch (options.location) { .url => |u| try Uri.parse(u), .uri => |u| u, }; - var server_header_buffer: [16 * 1024]u8 = undefined; - const method: http.Method = options.method orelse if (options.payload != null) .POST else .GET; - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, + const redirect_behavior: Request.RedirectBehavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled; + + var req = try request(client, method, uri, .{ + .redirect_behavior = redirect_behavior, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, @@ -1761,44 +1722,69 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { }); defer req.deinit(); - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - - try req.send(); - - if (options.payload) |payload| try req.writeAll(payload); - - try req.finish(); - try req.wait(); - - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. - }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, + if (options.payload) |payload| { + req.transfer_encoding = .{ .content_length = payload.len }; + var body = try req.sendBody(&.{}); + try body.writer.writeAll(payload); + try body.end(); + } else { + try req.sendBodiless(); } - return .{ - .status = req.response.status, + const redirect_buffer: []u8 = if (redirect_behavior == .unhandled) &.{} else options.redirect_buffer orelse + try client.allocator.alloc(u8, 8 * 1024); + defer if (options.redirect_buffer == null) client.allocator.free(redirect_buffer); + + var response = try req.receiveHead(redirect_buffer); + + const storage = options.response_storage orelse { + const reader = response.reader(&.{}); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; + return .{ .status = response.head.status }; }; + + const decompress_buffer: []u8 = switch (response.head.content_encoding) { + .identity => &.{}, + .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len), + else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024), + }; + defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer); + + var decompressor: http.Decompressor = undefined; + const reader = response.readerDecompressing(&decompressor, decompress_buffer); + const list = storage.list; + + if (storage.allocator) |allocator| { + reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + else => |e| return e, + }; + } else { + const buf = storage.append_limit.slice(list.unusedCapacitySlice()); + list.items.len += reader.readSliceShort(buf) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; + } + + return .{ .status = response.head.status }; +} + +pub fn sameParentDomain(parent_host: []const u8, child_host: []const u8) bool { + if (!std.ascii.endsWithIgnoreCase(child_host, parent_host)) return false; + if (child_host.len == parent_host.len) return true; + if (parent_host.len > child_host.len) return false; + return child_host[child_host.len - parent_host.len - 1] == '.'; +} + +test sameParentDomain { + try testing.expect(!sameParentDomain("foo.com", "bar.com")); + try testing.expect(sameParentDomain("foo.com", "foo.com")); + try testing.expect(sameParentDomain("foo.com", "bar.foo.com")); + try testing.expect(!sameParentDomain("bar.foo.com", "foo.com")); } test { _ = Response; - _ = &initDefaultProxies; } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 7ec5d5c11f..004741d1ae 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,139 +1,70 @@ -//! Blocking HTTP server implementation. -//! Handles a single connection's lifecycle. +//! Handles a single connection lifecycle. -connection: net.Server.Connection, -/// Keeps track of whether the Server is ready to accept a new request on the -/// same connection, and makes invalid API usage cause assertion failures -/// rather than HTTP protocol violations. -state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; +const Writer = std.io.Writer; -pub const State = enum { - /// The connection is available to be used for the first time, or reused. - ready, - /// An error occurred in `receiveHead`. - receiving_head, - /// A Request object has been obtained and from there a Response can be - /// opened. - received_head, - /// The client is uploading something to this Server. - receiving_body, - /// The connection is eligible for another HTTP request, however the client - /// and server did not negotiate a persistent connection. - closing, -}; +const Server = @This(); + +/// Data from the HTTP server to the HTTP client. +out: *Writer, +reader: http.Reader, /// Initialize an HTTP server that can respond to multiple requests on the same /// connection. +/// +/// The buffer of `in` must be large enough to store the client's entire HTTP +/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. +/// /// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { +pub fn init(in: *std.io.Reader, out: *Writer) Server { return .{ - .connection = connection, - .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, + .reader = .{ + .in = in, + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, + .out = out, }; } -pub const ReceiveHeadError = error{ - /// Client sent too many bytes of HTTP headers. - /// The HTTP specification suggests to respond with a 431 status code - /// before closing the connection. - HttpHeadersOversize, - /// Client sent headers that did not conform to the HTTP protocol. - HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, - /// Partial HTTP request was received but the connection was closed before - /// fully receiving the headers. - HttpRequestTruncated, - /// The client sent 0 bytes of headers before closing the stream. - /// In other words, a keep-alive connection was finally closed. - HttpConnectionClosing, -}; - -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. -pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - assert(s.state == .ready); - s.state = .received_head; - errdefer s.state = .receiving_head; - - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. - if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - - var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } - - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); - } +pub fn deinit(s: *Server) void { + s.reader.restituteHeadBuffer(); } -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { +pub const ReceiveHeadError = http.Reader.HeadError || error{ + /// Client sent headers that did not conform to the HTTP protocol. + /// + /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be + /// passed directly to `Request.Head.parse`. + HttpHeadersInvalid, +}; + +pub fn receiveHead(s: *Server) ReceiveHeadError!Request { + try s.reader.receiveHead(); return .{ .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, + // No need to track the returned error here since users can repeat the + // parse with the header buffer to get detailed diagnostics. + .head = Request.Head.parse(s.reader.head_buffer) catch return error.HttpHeadersInvalid, }; } pub const Request = struct { server: *Server, - /// Index into Server's read_buffer. - head_end: usize, + /// Pointers in this struct are invalidated with the next call to + /// `receiveHead`. head: Head, - reader_state: union { - remaining_content_length: u64, - chunk_parser: http.ChunkParser, - }, + respond_err: ?RespondError = null, - pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); - - deflate: std.compress.flate.Decompress, - gzip: std.compress.flate.Decompress, - zstd: std.compress.zstd.Decompress, - none: void, + pub const RespondError = error{ + /// The request contained an `expect` header with an unrecognized value. + HttpExpectationFailed, }; pub const Head = struct { @@ -146,7 +77,6 @@ pub const Request = struct { transfer_encoding: http.TransferEncoding, transfer_compression: http.ContentEncoding, keep_alive: bool, - compression: Compression, pub const ParseError = error{ UnknownHttpMethod, @@ -200,7 +130,6 @@ pub const Request = struct { .@"HTTP/1.0" => false, .@"HTTP/1.1" => true, }, - .compression = .none, }; while (it.next()) |line| { @@ -230,7 +159,7 @@ pub const Request = struct { const trimmed = mem.trim(u8, header_value, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (http.ContentEncoding.fromString(trimmed)) |ce| { head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; @@ -255,7 +184,7 @@ pub const Request = struct { if (next) |second| { const trimmed_second = mem.trim(u8, second, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported head.transfer_compression = transfer; @@ -299,7 +228,8 @@ pub const Request = struct { }; pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + assert(r.server.reader.state == .received_head); + return http.HeaderIterator.init(r.server.reader.head_buffer); } test iterateHeaders { @@ -310,22 +240,19 @@ pub const Request = struct { "TRansfer-encoding:\tdeflate, chunked \r\n" ++ "connectioN:\t keep-alive \r\n\r\n"; - var read_buffer: [500]u8 = undefined; - @memcpy(read_buffer[0..request_bytes.len], request_bytes); - var server: Server = .{ - .connection = undefined, - .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, + .reader = .{ + .in = undefined, + .state = .received_head, + .head_buffer = @constCast(request_bytes), + .interface = undefined, + }, + .out = undefined, }; var request: Request = .{ .server = &server, - .head_end = request_bytes.len, .head = undefined, - .reader_state = undefined, }; var it = request.iterateHeaders(); @@ -384,16 +311,22 @@ pub const Request = struct { /// no error is surfaced. /// /// Asserts status is not `continue`. - /// Asserts there are at most 25 extra_headers. /// Asserts that "\r\n" does not occur in any header name or value. pub fn respond( request: *Request, content: []const u8, options: RespondOptions, - ) Response.WriteError!void { - const max_extra_headers = 25; + ) ExpectContinueError!void { + try respondUnflushed(request, content, options); + try request.server.out.flush(); + } + + pub fn respondUnflushed( + request: *Request, + content: []const u8, + options: RespondOptions, + ) ExpectContinueError!void { assert(options.status != .@"continue"); - assert(options.extra_headers.len <= max_extra_headers); if (std.debug.runtime_safety) { for (options.extra_headers) |header| { assert(header.name.len != 0); @@ -402,6 +335,7 @@ pub const Request = struct { assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); } } + try writeExpectContinue(request); const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; const server_keep_alive = !transfer_encoding_none and options.keep_alive; @@ -409,130 +343,42 @@ pub const Request = struct { const phrase = options.reason orelse options.status.phrase() orelse ""; - var first_buffer: [500]u8 = undefined; - var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); - if (request.head.expect != null) { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); - return; - } - h.fixedWriter().print("{s} {d} {s}\r\n", .{ + const out = request.server.out; + try out.print("{s} {d} {s}\r\n", .{ @tagName(options.version), @intFromEnum(options.status), phrase, - }) catch unreachable; + }); switch (options.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"), } if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .chunked => try out.writeAll("transfer-encoding: chunked\r\n"), } else { - h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; + try out.print("content-length: {d}\r\n", .{content.len}); } - var chunk_header_buffer: [18]u8 = undefined; - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = h.items.ptr, - .len = h.items.len, - }; - iovecs_len += 1; - for (options.extra_headers) |header| { - iovecs[iovecs_len] = .{ - .base = header.name.ptr, - .len = header.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (header.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = header.value.ptr, - .len = header.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; + var vecs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&vecs); } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; + try out.writeAll("\r\n"); if (request.head.method != .HEAD) { const is_chunked = (options.transfer_encoding orelse .none) == .chunked; if (is_chunked) { - if (content.len > 0) { - const chunk_header = std.fmt.bufPrint( - &chunk_header_buffer, - "{x}\r\n", - .{content.len}, - ) catch unreachable; - - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "0\r\n\r\n", - .len = 5, - }; - iovecs_len += 1; + if (content.len > 0) try out.print("{x}\r\n{s}\r\n", .{ content.len, content }); + try out.writeAll("0\r\n\r\n"); } else if (content.len > 0) { - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; + try out.writeAll(content); } } - - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); } pub const RespondStreamingOptions = struct { - /// An externally managed slice of memory used to batch bytes before - /// sending. `respondStreaming` asserts this is large enough to store - /// the full HTTP response head. - /// - /// Must outlive the returned Response. - send_buffer: []u8, /// If provided, the response will use the content-length header; /// otherwise it will use transfer-encoding: chunked. content_length: ?u64 = null, @@ -540,254 +386,221 @@ pub const Request = struct { respond_options: RespondOptions = .{}, }; - /// The header is buffered but not sent until Response.flush is called. + /// The header is not guaranteed to be sent until `BodyWriter.flush` or + /// `BodyWriter.end` is called. /// /// If the request contains a body and the connection is to be reused, /// discards the request body, leaving the Server in the `ready` state. If /// this discarding fails, the connection is marked as not to be reused and /// no error is surfaced. /// - /// HEAD requests are handled transparently by setting a flag on the - /// returned Response to omit the body. However it may be worth noticing + /// HEAD requests are handled transparently by setting the + /// `BodyWriter.elide` flag on the returned `BodyWriter`, causing + /// the response stream to omit the body. However, it may be worth noticing /// that flag and skipping any expensive work that would otherwise need to /// be done to satisfy the request. /// - /// Asserts `send_buffer` is large enough to store the entire response header. /// Asserts status is not `continue`. - pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + pub fn respondStreaming( + request: *Request, + buffer: []u8, + options: RespondStreamingOptions, + ) ExpectContinueError!http.BodyWriter { + try writeExpectContinue(request); const o = options.respond_options; assert(o.status != .@"continue"); const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; const server_keep_alive = !transfer_encoding_none and o.keep_alive; const keep_alive = request.discardBody(server_keep_alive); const phrase = o.reason orelse o.status.phrase() orelse ""; + const out = request.server.out; - var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); + try out.print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }); - const elide_body = if (request.head.expect != null) eb: { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - break :eb true; - } else eb: { - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; + switch (o.version) { + .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"), + } - switch (o.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => try out.writeAll("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + try out.print("content-length: {d}\r\n", .{len}); + } else { + try out.writeAll("transfer-encoding: chunked\r\n"); + } - if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - .none => {}, - } else if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } + for (o.extra_headers) |header| { + assert(header.name.len != 0); + try out.writeAll(header.name); + try out.writeAll(": "); + try out.writeAll(header.value); + try out.writeAll("\r\n"); + } - for (o.extra_headers) |header| { - assert(header.name.len != 0); - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); - h.appendSliceAssumeCapacity("\r\n"); - } + try out.writeAll("\r\n"); + const elide_body = request.head.method == .HEAD; + const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) { + .chunked => .{ .chunked = .init }, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .{ .chunked = .init }; - h.appendSliceAssumeCapacity("\r\n"); - break :eb request.head.method == .HEAD; + return if (elide_body) .{ + .http_protocol_output = request.server.out, + .state = state, + .writer = .discarding(buffer), + } else .{ + .http_protocol_output = request.server.out, + .state = state, + .writer = .{ + .buffer = buffer, + .vtable = switch (state) { + .none => &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + .content_length => &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + .chunked => &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + .end => unreachable, + }, + }, }; + } + + pub const UpgradeRequest = union(enum) { + websocket: ?[]const u8, + other: []const u8, + none, + }; + + pub fn upgradeRequested(request: *const Request) UpgradeRequest { + switch (request.head.version) { + .@"HTTP/1.0" => return null, + .@"HTTP/1.1" => if (request.head.method != .GET) return null, + } + + var sec_websocket_key: ?[]const u8 = null; + var upgrade_name: ?[]const u8 = null; + var it = request.iterateHeaders(); + while (it.next()) |header| { + if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) { + sec_websocket_key = header.value; + } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) { + upgrade_name = header.value; + } + } + + const name = upgrade_name orelse return .none; + if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key }; + return .{ .other = name }; + } + + pub const WebSocketOptions = struct { + /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value). + key: []const u8, + reason: ?[]const u8 = null, + extra_headers: []const http.Header = &.{}, + }; + + /// The header is not guaranteed to be sent until `WebSocket.flush` is + /// called on the returned struct. + pub fn respondWebSocket(request: *Request, options: WebSocketOptions) Writer.Error!WebSocket { + if (request.head.expect != null) return error.HttpExpectationFailed; + + const out = request.server.out; + const version: http.Version = .@"HTTP/1.1"; + const status: http.Status = .switching_protocols; + const phrase = options.reason orelse status.phrase() orelse ""; + + assert(request.head.version == version); + assert(request.head.method == .GET); + + var sha1 = std.crypto.hash.Sha1.init(.{}); + sha1.update(options.key); + sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined; + sha1.final(&digest); + try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase }); + try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: "); + const base64_digest = try out.writableArray(28); + assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len); + out.advance(base64_digest.len); + try out.writeAll("\r\n"); + + for (options.extra_headers) |header| { + assert(header.name.len != 0); + try out.writeAll(header.name); + try out.writeAll(": "); + try out.writeAll(header.value); + try out.writeAll("\r\n"); + } + + try out.writeAll("\r\n"); return .{ - .stream = request.server.connection.stream, - .send_buffer = options.send_buffer, - .send_buffer_start = 0, - .send_buffer_end = h.items.len, - .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { - .chunked => .chunked, - .none => .none, - } else if (options.content_length) |len| .{ - .content_length = len, - } else .chunked, - .elide_body = elide_body, - .chunk_len = 0, + .input = request.server.reader.in, + .output = request.server.out, + .key = options.key, }; } - pub const ReadError = net.Stream.ReadError || error{ - HttpChunkInvalid, - HttpHeadersOversize, - }; - - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const remaining_content_length = &request.reader_state.remaining_content_length; - if (remaining_content_length.* == 0) { - s.state = .ready; - return 0; - } - assert(s.state == .receiving_body); - const available = try fill(s, request.head_end); - const len = @min(remaining_content_length.*, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - remaining_content_length.* -= len; - s.next_request_start += len; - if (remaining_content_length.* == 0) - s.state = .ready; - return len; - } - - fn fill(s: *Server, head_end: usize) ReadError![]u8 { - const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; - if (available.len > 0) return available; - s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); - return s.read_buffer[head_end..s.read_buffer_len]; - } - - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; - - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. - if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, - } - }, - } - } - return out_end; - } - - pub const ReaderError = Response.WriteError || error{ - /// The client sent an expect HTTP header value other than - /// "100-continue". - HttpExpectationFailed, - }; - /// In the case that the request contains "expect: 100-continue", this /// function writes the continuation header, which means it can fail with a /// write error. After sending the continuation header, it sets the /// request's expect field to `null`. /// /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { - const s = request.server; - assert(s.state == .received_head); - s.state = .receiving_body; - s.next_request_start = request.head_end; + /// + /// See `readerExpectNone` for an infallible alternative that cannot write + /// to the server output stream. + pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.io.Reader { + const flush = request.head.expect != null; + try writeExpectContinue(request); + if (flush) try request.server.out.flush(); + return readerExpectNone(request, buffer); + } - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); - request.head.expect = null; - } else { - return error.HttpExpectationFailed; - } - } + /// Asserts the expect header is `null`. The caller must handle the + /// expectation manually and then set the value to `null` prior to calling + /// this function. + /// + /// Asserts that this function is only called once. + pub fn readerExpectNone(request: *Request, buffer: []u8) *std.io.Reader { + assert(request.server.reader.state == .received_head); + assert(request.head.expect == null); + if (!request.head.method.requestHasBody()) return .ending; + return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length); + } - switch (request.head.transfer_encoding) { - .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; - return .{ - .readFn = read_chunked, - .context = request, - }; - }, - .none => { - request.reader_state = .{ - .remaining_content_length = request.head.content_length orelse 0, - }; - return .{ - .readFn = read_cl, - .context = request, - }; - }, - } + pub const ExpectContinueError = error{ + /// Failed to write "HTTP/1.1 100 Continue\r\n\r\n" to the stream. + WriteFailed, + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + pub fn writeExpectContinue(request: *Request) ExpectContinueError!void { + const expect = request.head.expect orelse return; + if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed; + try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; } /// Returns whether the connection should remain persistent. - /// If it would fail, it instead sets the Server state to `receiving_body` + /// + /// If it would fail, it instead sets the Server state to receiving body /// and returns false. fn discardBody(request: *Request, keep_alive: bool) bool { // Prepare to receive another request on the same connection. @@ -798,350 +611,175 @@ pub const Request = struct { // or the request body. // If the connection won't be kept alive, then none of this matters // because the connection will be severed after the response is sent. - const s = request.server; - if (keep_alive and request.head.keep_alive) switch (s.state) { + const r = &request.server.reader; + if (keep_alive and request.head.keep_alive) switch (r.state) { .received_head => { - const r = request.reader() catch return false; - _ = r.discard() catch return false; - assert(s.state == .ready); + if (request.head.method.requestHasBody()) { + assert(request.head.transfer_encoding != .none or request.head.content_length != null); + const reader_interface = request.readerExpectContinue(&.{}) catch return false; + _ = reader_interface.discardRemaining() catch return false; + assert(r.state == .ready); + } else { + r.state = .ready; + } return true; }, - .receiving_body, .ready => return true, + .body_remaining_content_length, .body_remaining_chunk_len, .body_none, .ready => return true, else => unreachable, }; // Avoid clobbering the state in case a reading stream already exists. - switch (s.state) { - .received_head => s.state = .closing, + switch (r.state) { + .received_head => r.state = .closing, else => {}, } return false; } }; -pub const Response = struct { - stream: net.Stream, - send_buffer: []u8, - /// Index of the first byte in `send_buffer`. - /// This is 0 unless a short write happens in `write`. - send_buffer_start: usize, - /// Index of the last byte + 1 in `send_buffer`. - send_buffer_end: usize, - /// `null` means transfer-encoding: chunked. - /// As a debugging utility, counts down to zero as bytes are written. - transfer_encoding: TransferEncoding, - elide_body: bool, - /// Indicates how much of the end of the `send_buffer` corresponds to a - /// chunk. This amount of data will be wrapped by an HTTP chunk header. - chunk_len: usize, +/// See https://tools.ietf.org/html/rfc6455 +pub const WebSocket = struct { + key: []const u8, + input: *std.io.Reader, + output: *Writer, - pub const TransferEncoding = union(enum) { - /// End of connection signals the end of the stream. - none, - /// As a debugging utility, counts down to zero as bytes are written. - content_length: u64, - /// Each chunk is wrapped in a header and trailer. - chunked, + pub const Header0 = packed struct(u8) { + opcode: Opcode, + rsv3: u1 = 0, + rsv2: u1 = 0, + rsv1: u1 = 0, + fin: bool, }; - pub const WriteError = net.Stream.WriteError; - - /// When using content-length, asserts that the amount of data sent matches - /// the value sent in the header, then calls `flush`. - /// Otherwise, transfer-encoding: chunked is being used, and it writes the - /// end-of-stream message, then flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .content_length => |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - }, - .none => { - try flush_cl(r); - }, - .chunked => { - try flush_chunked(r, &.{}); - }, - } - r.* = undefined; - } - - pub const EndChunkedOptions = struct { - trailers: []const http.Header = &.{}, + pub const Header1 = packed struct(u8) { + payload_len: enum(u7) { + len16 = 126, + len64 = 127, + _, + }, + mask: bool, }; - /// Asserts that the Response is using transfer-encoding: chunked. - /// Writes the end-of-stream message and any optional trailers, then - /// flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { - assert(r.transfer_encoding == .chunked); - try flush_chunked(r, options.trailers); - r.* = undefined; - } + pub const Opcode = enum(u4) { + continuation = 0, + text = 1, + binary = 2, + connection_close = 8, + ping = 9, + /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional + /// heartbeat. A response to an unsolicited Pong frame is not expected." + pong = 10, + _, + }; - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } + pub const ReadSmallTextMessageError = error{ + ConnectionClose, + UnexpectedOpCode, + MessageTooBig, + MissingMaskBit, + }; - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); + pub const SmallMessage = struct { + /// Can be text, binary, or ping. + opcode: Opcode, + data: []u8, + }; - var trash: u64 = std.math.maxInt(u64); - const len = switch (r.transfer_encoding) { - .content_length => |*len| len, - else => &trash, - }; + /// Reads the next message from the WebSocket stream, failing if the + /// message does not fit into the input buffer. The returned memory points + /// into the input buffer and is invalidated on the next read. + pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { + const in = ws.input; + while (true) { + const h0 = in.takeStruct(Header0); + const h1 = in.takeStruct(Header1); - if (r.elide_body) { - len.* -= bytes.len; - return bytes.len; - } - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - }; - const n = try r.stream.writev(&iovecs); - - if (n >= send_buffer_len) { - // It was enough to reset the buffer. - r.send_buffer_start = 0; - r.send_buffer_end = 0; - const bytes_n = n - send_buffer_len; - len.* -= bytes_n; - return bytes_n; + switch (h0.opcode) { + .text, .binary, .pong, .ping => {}, + .connection_close => return error.ConnectionClose, + .continuation => return error.UnexpectedOpCode, + _ => return error.UnexpectedOpCode, } - // It didn't even make it through the existing buffer, let - // alone the new bytes provided. - r.send_buffer_start += n; - return 0; - } + if (!h0.fin) return error.MessageTooBig; + if (!h1.mask) return error.MissingMaskBit; - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - len.* -= bytes.len; - return bytes.len; - } - - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - assert(r.transfer_encoding == .chunked); - - if (r.elide_body) - return bytes.len; - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - const chunk_len = r.chunk_len + bytes.len; - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - - var iovecs: [5]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len - r.chunk_len, - }, - .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }, - .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - .{ - .base = "\r\n", - .len = 2, - }, + const len: usize = switch (h1.payload_len) { + .len16 => try in.takeInt(u16, .big), + .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig, + else => @intFromEnum(h1.payload_len), }; - // TODO make this writev instead of writevAll, which involves - // complicating the logic of this function. - try r.stream.writevAll(&iovecs); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return bytes.len; - } + if (len > in.buffer.len) return error.MessageTooBig; + const mask: u32 = @bitCast((try in.takeArray(4)).*); + const payload = try in.take(len); - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - r.chunk_len += bytes.len; - return bytes.len; - } + // Skip pongs. + if (h0.opcode == .pong) continue; - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); + // The last item may contain a partial word of unused data. + const floored_len = (payload.len / 4) * 4; + const u32_payload: []align(1) u32 = @ptrCast(payload[0..floored_len]); + for (u32_payload) |*elem| elem.* ^= mask; + const mask_bytes: []const u8 = @ptrCast(&mask); + for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m| + leftover.* ^= m; + + return .{ + .opcode = h0.opcode, + .data = payload, + }; } } - /// Sends all buffered data to the client. - /// This is redundant after calling `end`. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .none, .content_length => return flush_cl(r), - .chunked => return flush_chunked(r, null), - } + pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { + try writeMessageVecUnflushed(ws, &.{data}, op); + try ws.output.flush(); } - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; + pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { + try writeMessageVecUnflushed(ws, &.{data}, op); } - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { - const max_trailers = 25; - if (end_trailers) |trailers| assert(trailers.len <= max_trailers); - assert(r.transfer_encoding == .chunked); + pub fn writeMessageVec(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void { + try writeMessageVecUnflushed(ws, data, op); + try ws.output.flush(); + } - const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; - - if (r.elide_body) { - try r.stream.writeAll(http_headers); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return; - } - - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; - - var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = http_headers.ptr, - .len = http_headers.len, + pub fn writeMessageVecUnflushed(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void { + const total_len = l: { + var total_len: u64 = 0; + for (data) |iovec| total_len += iovec.len; + break :l total_len; }; - iovecs_len += 1; - - if (r.chunk_len > 0) { - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - if (end_trailers) |trailers| { - iovecs[iovecs_len] = .{ - .base = "0\r\n", - .len = 3, - }; - iovecs_len += 1; - - for (trailers) |trailer| { - iovecs[iovecs_len] = .{ - .base = trailer.name.ptr, - .len = trailer.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (trailer.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = trailer.value.ptr, - .len = trailer.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - try r.stream.writevAll(iovecs[0..iovecs_len]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - } - - pub fn writer(r: *Response) std.io.AnyWriter { - return .{ - .writeFn = switch (r.transfer_encoding) { - .none, .content_length => write_cl, - .chunked => write_chunked, + const out = ws.output; + try out.writeStruct(@as(Header0, .{ + .opcode = op, + .fin = true, + })); + switch (total_len) { + 0...125 => try out.writeStruct(@as(Header1, .{ + .payload_len = @enumFromInt(total_len), + .mask = false, + })), + 126...0xffff => { + try out.writeStruct(@as(Header1, .{ + .payload_len = .len16, + .mask = false, + })); + try out.writeInt(u16, @intCast(total_len), .big); }, - .context = r, - }; + else => { + try out.writeStruct(@as(Header1, .{ + .payload_len = .len64, + .mask = false, + })); + try out.writeInt(u64, total_len, .big); + }, + } + try out.writeVecAll(data); + } + + pub fn flush(ws: *WebSocket) Writer.Error!void { + try ws.output.flush(); } }; - -fn rebase(s: *Server, index: usize) void { - const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; - const dest = s.read_buffer[index..][0..leftover.len]; - if (leftover.len <= s.next_request_start - index) { - @memcpy(dest, leftover); - } else { - mem.copyBackwards(u8, dest, leftover); - } - s.read_buffer_len = index + leftover.len; -} - -const std = @import("../std.zig"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/lib/std/http/WebSocket.zig b/lib/std/http/WebSocket.zig deleted file mode 100644 index b9a66cdbd6..0000000000 --- a/lib/std/http/WebSocket.zig +++ /dev/null @@ -1,246 +0,0 @@ -//! See https://tools.ietf.org/html/rfc6455 - -const builtin = @import("builtin"); -const std = @import("std"); -const WebSocket = @This(); -const assert = std.debug.assert; -const native_endian = builtin.cpu.arch.endian(); - -key: []const u8, -request: *std.http.Server.Request, -recv_fifo: std.fifo.LinearFifo(u8, .Slice), -reader: std.io.AnyReader, -response: std.http.Server.Response, -/// Number of bytes that have been peeked but not discarded yet. -outstanding_len: usize, - -pub const InitError = error{WebSocketUpgradeMissingKey} || - std.http.Server.Request.ReaderError; - -pub fn init( - request: *std.http.Server.Request, - send_buffer: []u8, - recv_buffer: []align(4) u8, -) InitError!?WebSocket { - switch (request.head.version) { - .@"HTTP/1.0" => return null, - .@"HTTP/1.1" => if (request.head.method != .GET) return null, - } - - var sec_websocket_key: ?[]const u8 = null; - var upgrade_websocket: bool = false; - var it = request.iterateHeaders(); - while (it.next()) |header| { - if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) { - sec_websocket_key = header.value; - } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) { - if (!std.ascii.eqlIgnoreCase(header.value, "websocket")) - return null; - upgrade_websocket = true; - } - } - if (!upgrade_websocket) - return null; - - const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey; - - var sha1 = std.crypto.hash.Sha1.init(.{}); - sha1.update(key); - sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined; - sha1.final(&digest); - var base64_digest: [28]u8 = undefined; - assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len); - - request.head.content_length = std.math.maxInt(u64); - - return .{ - .key = key, - .recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer), - .reader = try request.reader(), - .response = request.respondStreaming(.{ - .send_buffer = send_buffer, - .respond_options = .{ - .status = .switching_protocols, - .extra_headers = &.{ - .{ .name = "upgrade", .value = "websocket" }, - .{ .name = "connection", .value = "upgrade" }, - .{ .name = "sec-websocket-accept", .value = &base64_digest }, - }, - .transfer_encoding = .none, - }, - }), - .request = request, - .outstanding_len = 0, - }; -} - -pub const Header0 = packed struct(u8) { - opcode: Opcode, - rsv3: u1 = 0, - rsv2: u1 = 0, - rsv1: u1 = 0, - fin: bool, -}; - -pub const Header1 = packed struct(u8) { - payload_len: enum(u7) { - len16 = 126, - len64 = 127, - _, - }, - mask: bool, -}; - -pub const Opcode = enum(u4) { - continuation = 0, - text = 1, - binary = 2, - connection_close = 8, - ping = 9, - /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional - /// heartbeat. A response to an unsolicited Pong frame is not expected." - pong = 10, - _, -}; - -pub const ReadSmallTextMessageError = error{ - ConnectionClose, - UnexpectedOpCode, - MessageTooBig, - MissingMaskBit, -} || RecvError; - -pub const SmallMessage = struct { - /// Can be text, binary, or ping. - opcode: Opcode, - data: []u8, -}; - -/// Reads the next message from the WebSocket stream, failing if the message does not fit -/// into `recv_buffer`. -pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { - while (true) { - const header_bytes = (try recv(ws, 2))[0..2]; - const h0: Header0 = @bitCast(header_bytes[0]); - const h1: Header1 = @bitCast(header_bytes[1]); - - switch (h0.opcode) { - .text, .binary, .pong, .ping => {}, - .connection_close => return error.ConnectionClose, - .continuation => return error.UnexpectedOpCode, - _ => return error.UnexpectedOpCode, - } - - if (!h0.fin) return error.MessageTooBig; - if (!h1.mask) return error.MissingMaskBit; - - const len: usize = switch (h1.payload_len) { - .len16 => try recvReadInt(ws, u16), - .len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig, - else => @intFromEnum(h1.payload_len), - }; - if (len > ws.recv_fifo.buf.len) return error.MessageTooBig; - - const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*); - const payload = try recv(ws, len); - - // Skip pongs. - if (h0.opcode == .pong) continue; - - // The last item may contain a partial word of unused data. - const floored_len = (payload.len / 4) * 4; - const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len])); - for (u32_payload) |*elem| elem.* ^= mask; - const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len]; - for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m; - - return .{ - .opcode = h0.opcode, - .data = payload, - }; - } -} - -const RecvError = std.http.Server.Request.ReadError || error{EndOfStream}; - -fn recv(ws: *WebSocket, len: usize) RecvError![]u8 { - ws.recv_fifo.discard(ws.outstanding_len); - assert(len <= ws.recv_fifo.buf.len); - if (len > ws.recv_fifo.count) { - const small_buf = ws.recv_fifo.writableSlice(0); - const needed = len - ws.recv_fifo.count; - const buf = if (small_buf.len >= needed) small_buf else b: { - ws.recv_fifo.realign(); - break :b ws.recv_fifo.writableSlice(0); - }; - const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed))); - if (n < needed) return error.EndOfStream; - ws.recv_fifo.update(n); - } - ws.outstanding_len = len; - // TODO: improve the std lib API so this cast isn't necessary. - return @constCast(ws.recv_fifo.readableSliceOfLen(len)); -} - -fn recvReadInt(ws: *WebSocket, comptime I: type) !I { - const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*); - return switch (native_endian) { - .little => @byteSwap(unswapped), - .big => unswapped, - }; -} - -pub const WriteError = std.http.Server.Response.WriteError; - -pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void { - const iovecs: [1]std.posix.iovec_const = .{ - .{ .base = message.ptr, .len = message.len }, - }; - return writeMessagev(ws, &iovecs, opcode); -} - -pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void { - const total_len = l: { - var total_len: u64 = 0; - for (message) |iovec| total_len += iovec.len; - break :l total_len; - }; - - var header_buf: [2 + 8]u8 = undefined; - header_buf[0] = @bitCast(@as(Header0, .{ - .opcode = opcode, - .fin = true, - })); - const header = switch (total_len) { - 0...125 => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = @enumFromInt(total_len), - .mask = false, - })); - break :blk header_buf[0..2]; - }, - 126...0xffff => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len16, - .mask = false, - })); - std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big); - break :blk header_buf[0..4]; - }, - else => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len64, - .mask = false, - })); - std.mem.writeInt(u64, header_buf[2..10], total_len, .big); - break :blk header_buf[0..10]; - }, - }; - - const response = &ws.response; - try response.writeAll(header); - for (message) |iovec| - try response.writeAll(iovec.base[0..iovec.len]); - try response.flush(); -} diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig deleted file mode 100644 index 797ed989ad..0000000000 --- a/lib/std/http/protocol.zig +++ /dev/null @@ -1,464 +0,0 @@ -const std = @import("../std.zig"); -const builtin = @import("builtin"); -const testing = std.testing; -const mem = std.mem; - -const assert = std.debug.assert; - -pub const State = enum { - invalid, - - // Begin header and trailer parsing states. - - start, - seen_n, - seen_r, - seen_rn, - seen_rnr, - finished, - - // Begin transfer-encoding: chunked parsing states. - - chunk_head_size, - chunk_head_ext, - chunk_head_r, - chunk_data, - chunk_data_suffix, - chunk_data_suffix_r, - - /// Returns true if the parser is in a content state (ie. not waiting for more headers). - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, - }; - } -}; - -pub const HeadersParser = struct { - state: State = .start, - /// A fixed buffer of len `max_header_bytes`. - /// Pointers into this buffer are not stable until after a message is complete. - header_bytes_buffer: []u8, - header_bytes_len: u32, - next_chunk_length: u64, - /// `false`: headers. `true`: trailers. - done: bool, - - /// Initializes the parser with a provided buffer `buf`. - pub fn init(buf: []u8) HeadersParser { - return .{ - .header_bytes_buffer = buf, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - /// Reinitialize the parser. - /// Asserts the parser is in the "done" state. - pub fn reset(hp: *HeadersParser) void { - assert(hp.done); - hp.* = .{ - .state = .start, - .header_bytes_buffer = hp.header_bytes_buffer, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - pub fn get(hp: HeadersParser) []u8 { - return hp.header_bytes_buffer[0..hp.header_bytes_len]; - } - - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - var hp: std.http.HeadParser = .{ - .state = switch (r.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - else => unreachable, - }, - }; - const result = hp.feed(bytes); - r.state = switch (hp.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - }; - return @intCast(result); - } - - pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - var cp: std.http.ChunkParser = .{ - .state = switch (r.state) { - .chunk_head_size => .head_size, - .chunk_head_ext => .head_ext, - .chunk_head_r => .head_r, - .chunk_data => .data, - .chunk_data_suffix => .data_suffix, - .chunk_data_suffix_r => .data_suffix_r, - .invalid => .invalid, - else => unreachable, - }, - .chunk_len = r.next_chunk_length, - }; - const result = cp.feed(bytes); - r.state = switch (cp.state) { - .head_size => .chunk_head_size, - .head_ext => .chunk_head_ext, - .head_r => .chunk_head_r, - .data => .chunk_data, - .data_suffix => .chunk_data_suffix, - .data_suffix_r => .chunk_data_suffix_r, - .invalid => .invalid, - }; - r.next_chunk_length = cp.chunk_len; - return @intCast(result); - } - - /// Returns whether or not the parser has finished parsing a complete - /// message. A message is only complete after the entire body has been read - /// and any trailing headers have been parsed. - pub fn isComplete(r: *HeadersParser) bool { - return r.done and r.state == .finished; - } - - pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - - /// Pushes `in` into the parser. Returns the number of bytes consumed by - /// the header. Any header bytes are appended to `header_bytes_buffer`. - pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { - if (hp.state.isContent()) return 0; - - const i = hp.findHeadersEnd(in); - const data = in[0..i]; - if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) - return error.HttpHeadersOversize; - - @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); - hp.header_bytes_len += @intCast(data.len); - - return i; - } - - pub const ReadError = error{ - HttpChunkInvalid, - }; - - /// Reads the body of the message into `buffer`. Returns the number of - /// bytes placed in the buffer. - /// - /// If `skip` is true, the buffer will be unused and the body will be skipped. - /// - /// See `std.http.Client.Connection for an example of `conn`. - pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { - assert(r.state.isContent()); - if (r.done) return 0; - - var out_index: usize = 0; - while (true) { - switch (r.state) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - const data_avail = r.next_chunk_length; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return out_index; - } else if (out_index < buffer.len) { - const out_avail = buffer.len - out_index; - - const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); - const nread = try conn.read(buffer[0..can_read]); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return nread; - } else { - return out_index; - } - }, - .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const i = r.findChunkedLen(conn.peek()); - conn.drop(@intCast(i)); - - switch (r.state) { - .invalid => return error.HttpChunkInvalid, - .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, conn.peek(), "\r\n")) { - r.state = .finished; - conn.drop(2); - } else { - // The trailer section is formatted identically - // to the header section. - r.state = .seen_rn; - } - r.done = true; - - return out_index; - }, - else => return out_index, - } - - continue; - }, - .chunk_data => { - const data_avail = r.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - } else if (out_avail > 0) { - const can_read: usize = @intCast(@min(data_avail, out_avail)); - const nread = try conn.read(buffer[out_index..][0..can_read]); - r.next_chunk_length -= nread; - out_index += nread; - } - - if (r.next_chunk_length == 0) { - r.state = .chunk_data_suffix; - continue; - } - - return out_index; - }, - } - } - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @as(u16, @bitCast(array.*)); -} - -inline fn int24(array: *const [3]u8) u24 { - return @as(u24, @bitCast(array.*)); -} - -inline fn int32(array: *const [4]u8) u32 { - return @as(u32, @bitCast(array.*)); -} - -inline fn intShift(comptime T: type, x: anytype) T { - switch (@import("builtin").cpu.arch.endian()) { - .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), - .big => return @as(T, @truncate(x)), - } -} - -/// A buffered (and peekable) Connection. -const MockBufferedConnection = struct { - pub const buffer_size = 0x2000; - - conn: std.io.FixedBufferStream([]const u8), - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, - - pub fn fill(conn: *MockBufferedConnection) ReadError!void { - if (conn.end != conn.start) return; - - const nread = try conn.conn.read(conn.buf[0..]); - if (nread == 0) return error.EndOfStream; - conn.start = 0; - conn.end = @as(u16, @truncate(nread)); - } - - pub fn peek(conn: *MockBufferedConnection) []const u8 { - return conn.buf[conn.start..conn.end]; - } - - pub fn drop(conn: *MockBufferedConnection, num: u16) void { - conn.start += num; - } - - pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = conn.end - conn.start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @as(u16, @truncate(@min(available, left))); - - @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); - out_index += can_read; - conn.start += can_read; - - continue; - } - - if (left > conn.buf.len) { - // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; - pub const Reader = std.io.GenericReader(*MockBufferedConnection, ReadError, read); - - pub fn reader(conn: *MockBufferedConnection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.GenericWriter(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } -}; - -test "HeadersParser.read length" { - // mock BufferedConnection for read - var headers_buf: [256]u8 = undefined; - - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - var buf: [8]u8 = undefined; - - r.next_chunk_length = 5; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked trailer" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); -} diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 33bc2eb191..4c3466d5c9 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -10,32 +10,33 @@ const expectError = std.testing.expectError; test "trailers" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [1024]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try serve(&request); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); } } fn serve(request: *http.Server.Request) !void { try expectEqualStrings(request.head.target, "/trailer"); - var send_buffer: [1024]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, - }); - try response.writeAll("Hello, "); + var response = try request.respondStreaming(&.{}, .{}); + try response.writer.writeAll("Hello, "); try response.flush(); - try response.writeAll("World!\n"); + try response.writer.writeAll("World!\n"); try response.flush(); try response.endChunked(.{ .trailers = &.{ @@ -58,34 +59,33 @@ test "trailers" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&.{}); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - var it = req.response.iterateHeaders(); { + var it = response.head.iterateHeaders(); const header = it.next().?; try expect(!it.is_trailer); try expectEqualStrings("transfer-encoding", header.name); try expectEqualStrings("chunked", header.value); + try expectEqual(null, it.next()); } { + var it = response.iterateTrailers(); const header = it.next().?; try expect(it.is_trailer); try expectEqualStrings("X-Checksum", header.name); try expectEqualStrings("aaaa", header.value); + try expectEqual(null, it.next()); } - try expectEqual(null, it.next()); } // connection has been kept alive @@ -94,19 +94,24 @@ test "trailers" { test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) !void { - var header_buffer: [8192]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [8192]u8 = undefined; + var send_buffer: [500]u8 = undefined; + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try (try request.reader()).readAll(&buf); - try expect(mem.eql(u8, buf[0..n], "ABCD")); + var br = try request.readerExpectContinue(&.{}); + const n = try br.readSliceShort(&buf); + try expectEqualStrings("ABCD", buf[0..n]); try request.respond("message from server!\n", .{ .extra_headers = &.{ @@ -154,16 +159,20 @@ test "HTTP server handles a chunked transfer coding request" { test "echo content server" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; - accept: while (true) { - const conn = try net_server.accept(); - defer conn.stream.close(); + accept: while (!test_server.shutting_down) { + const connection = try net_server.accept(); + defer connection.stream.close(); - var http_server = http.Server.init(conn, &read_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :accept, else => |e| return e, @@ -173,7 +182,7 @@ test "echo content server" { } if (request.head.expect) |expect_header_value| { if (mem.eql(u8, expect_header_value, "garbage")) { - try expectError(error.HttpExpectationFailed, request.reader()); + try expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{})); try request.respond("", .{ .keep_alive = false }); continue; } @@ -195,16 +204,14 @@ test "echo content server" { // request.head.target, //}); - const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .limited(8192)); defer std.testing.allocator.free(body); try expect(mem.startsWith(u8, request.head.target, "/echo-content")); try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("text/plain", request.head.content_type.?); - var send_buffer: [100]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = switch (request.head.transfer_encoding) { .chunked => null, .none => len: { @@ -213,9 +220,8 @@ test "echo content server" { }, }, }); - try response.flush(); // Test an early flush to send the HTTP headers before the body. - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); @@ -241,35 +247,36 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { // In this case, the response is expected to stream until the connection is // closed, indicating the end of the body. const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1000]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1000]u8 = undefined; + var send_buffer: [500]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/foo"); - var send_buffer: [500]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var buf: [30]u8 = undefined; + var response = try request.respondStreaming(&buf, .{ .respond_options = .{ .transfer_encoding = .none, }, }); - var total: usize = 0; + const w = &response.writer; for (0..500) |i| { - var buf: [30]u8 = undefined; - const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); - try response.writeAll(line); - total += line.len; + try w.print("{d}, ah ha ha!\n", .{i}); } - try expectEqual(7390, total); + try expectEqual(7390, w.count); + try w.flush(); try response.end(); - try expectEqual(.closing, server.state); + try expectEqual(.closing, server.reader.state); } } }); @@ -308,15 +315,20 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { test "receiving arbitrary http headers from the client" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [666]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [666]u8 = undefined; + var send_buffer: [777]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &read_buffer); - try expectEqual(.ready, server.state); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings("/bar", request.head.target); var it = request.iterateHeaders(); @@ -368,19 +380,21 @@ test "general client/server API coverage" { return error.SkipZigTest; } - const global = struct { - var handle_new_requests = true; - }; const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var client_header_buffer: [1024]u8 = undefined; - outer: while (global.handle_new_requests) { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; + + outer: while (!test_server.shutting_down) { var connection = try net_server.accept(); defer connection.stream.close(); - var http_server = http.Server.init(connection, &client_header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :outer, else => |e| return e, @@ -399,14 +413,11 @@ test "general client/server API coverage" { }); const gpa = std.testing.allocator; - const body = try (try request.reader()).readAllAlloc(gpa, 8192); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); - var send_buffer: [100]u8 = undefined; - if (mem.startsWith(u8, request.head.target, "/get")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null) 14 else @@ -417,20 +428,19 @@ test "general client/server API coverage" { }, }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); // Writing again would cause an assertion failure. } else if (mem.startsWith(u8, request.head.target, "/large")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = 14 * 1024 + 14 * 10, }); try response.flush(); // Test an early flush to send the HTTP headers before the body. - const w = response.writer(); + const w = &response.writer; var i: u32 = 0; while (i < 5) : (i += 1) { @@ -446,8 +456,7 @@ test "general client/server API coverage" { try response.end(); } else if (mem.eql(u8, request.head.target, "/redirect/1")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .status = .found, .extra_headers = &.{ @@ -456,7 +465,7 @@ test "general client/server API coverage" { }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("Redirected!\n"); try response.end(); @@ -524,17 +533,13 @@ test "general client/server API coverage" { return s.listen_address.in.getPort(); } }); - defer { - global.handle_new_requests = false; - test_server.destroy(); - } + defer test_server.destroy(); const log = std.log.scoped(.client); const gpa = std.testing.allocator; var client: http.Client = .{ .allocator = gpa }; - errdefer client.deinit(); - // defer client.deinit(); handled below + defer client.deinit(); const port = test_server.port(); @@ -544,20 +549,18 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been kept alive @@ -569,16 +572,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192 * 1024); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192 * 1024)); defer gpa.free(body); try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); @@ -593,21 +594,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try expectEqual(14, req.response.content_length.?); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expectEqual(14, response.head.content_length.?); } // connection has been kept alive @@ -619,20 +618,18 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been kept alive @@ -644,21 +641,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try expect(req.response.transfer_encoding == .chunked); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expect(response.head.transfer_encoding == .chunked); } // connection has been kept alive @@ -670,21 +665,20 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .keep_alive = false, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been closed @@ -696,26 +690,25 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .extra_headers = &.{ .{ .name = "empty", .value = "" }, }, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - try std.testing.expectEqual(.ok, req.response.status); + try std.testing.expectEqual(.ok, response.head.status); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - var it = req.response.iterateHeaders(); + var it = response.head.iterateHeaders(); { const header = it.next().?; try expect(!it.is_trailer); @@ -740,16 +733,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -764,16 +755,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -788,16 +777,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -812,17 +799,17 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - req.wait() catch |err| switch (err) { + try req.sendBodiless(); + if (req.receiveHead(&redirect_buffer)) |_| { + return error.TestFailed; + } else |err| switch (err) { error.TooManyHttpRedirects => {}, else => return err, - }; + } } { // redirect to encoded url @@ -831,16 +818,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Encoded redirect successful!\n", body); @@ -855,14 +840,12 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - const result = req.wait(); + try req.sendBodiless(); + const result = req.receiveHead(&redirect_buffer); // a proxy without an upstream is likely to return a 5xx status. if (client.http_proxy == null) { @@ -872,77 +855,40 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** - const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); - defer gpa.free(location); - const uri = try std.Uri.parse(location); - - const total_connections = client.connection_pool.free_size + 64; - var requests = try gpa.alloc(http.Client.Request, total_connections); - defer gpa.free(requests); - - var header_bufs = std.ArrayList([]u8).init(gpa); - defer header_bufs.deinit(); - defer for (header_bufs.items) |item| gpa.free(item); - - for (0..total_connections) |i| { - const headers_buf = try gpa.alloc(u8, 1024); - try header_bufs.append(headers_buf); - var req = try client.open(.GET, uri, .{ - .server_header_buffer = headers_buf, - }); - req.response.parser.done = true; - req.connection.?.closing = false; - requests[i] = req; - } - - for (0..total_connections) |i| { - requests[i].deinit(); - } - - // free connections should be full now - try expect(client.connection_pool.free_len == client.connection_pool.free_size); - } - - client.deinit(); - - { - global.handle_new_requests = false; - - const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address); - conn.close(); - } } test "Server streams both reading and writing" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); - - var server = http.Server.init(conn, &header_buffer); - var request = try server.receiveHead(); - const reader = try request.reader(); - + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; var send_buffer: [777]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + + const connection = try net_server.accept(); + defer connection.stream.close(); + + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + var request = try server.receiveHead(); + var read_buffer: [100]u8 = undefined; + var br = try request.readerExpectContinue(&read_buffer); + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .transfer_encoding = .none, // Causes keep_alive=false }, }); - const writer = response.writer(); + const w = &response.writer; while (true) { try response.flush(); - var buf: [100]u8 = undefined; - const n = try reader.read(&buf); - if (n == 0) break; - const sub_buf = buf[0..n]; - for (sub_buf) |*b| b.* = std.ascii.toUpper(b.*); - try writer.writeAll(sub_buf); + const buf = br.peekGreedy(1) catch |err| switch (err) { + error.EndOfStream => break, + error.ReadFailed => return error.ReadFailed, + }; + br.toss(buf.len); + for (buf) |*b| b.* = std.ascii.toUpper(b.*); + try w.writeAll(buf); } try response.end(); } @@ -952,27 +898,24 @@ test "Server streams both reading and writing" { var client: http.Client = .{ .allocator = std.testing.allocator }; defer client.deinit(); - var server_header_buffer: [555]u8 = undefined; - var req = try client.open(.POST, .{ + var redirect_buffer: [555]u8 = undefined; + var req = try client.request(.POST, .{ .scheme = "http", .host = .{ .raw = "127.0.0.1" }, .port = test_server.port(), .path = .{ .percent_encoded = "/" }, - }, .{ - .server_header_buffer = &server_header_buffer, - }); + }, .{}); defer req.deinit(); req.transfer_encoding = .chunked; - try req.send(); - try req.wait(); + var body_writer = try req.sendBody(&.{}); + var response = try req.receiveHead(&redirect_buffer); - try req.writeAll("one "); - try req.writeAll("fish"); + try body_writer.writer.writeAll("one "); + try body_writer.writer.writeAll("fish"); + try body_writer.end(); - try req.finish(); - - const body = try req.reader().readAllAlloc(std.testing.allocator, 8192); + const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .limited(8192)); defer std.testing.allocator.free(body); try expectEqualStrings("ONE FISH", body); @@ -987,9 +930,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, @@ -998,14 +940,14 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .{ .content_length = 14 }; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1021,9 +963,8 @@ fn echoTests(client: *http.Client, port: u16) !void { .{port}, )); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, @@ -1032,14 +973,14 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1053,8 +994,8 @@ fn echoTests(client: *http.Client, port: u16) !void { const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port}); defer gpa.free(location); - var body = std.ArrayList(u8).init(gpa); - defer body.deinit(); + var body: std.ArrayListUnmanaged(u8) = .empty; + defer body.deinit(gpa); const res = try client.fetch(.{ .location = .{ .url = location }, @@ -1063,7 +1004,7 @@ fn echoTests(client: *http.Client, port: u16) !void { .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, - .response_storage = .{ .dynamic = &body }, + .response_storage = .{ .allocator = gpa, .list = &body }, }); try expectEqual(.ok, res.status); try expectEqualStrings("Hello, World!\n", body.items); @@ -1074,9 +1015,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "expect", .value = "100-continue" }, .{ .name = "content-type", .value = "text/plain" }, @@ -1086,15 +1026,15 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); - try expectEqual(.ok, req.response.status); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.ok, response.head.status); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1105,9 +1045,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, .{ .name = "expect", .value = "garbage" }, @@ -1117,23 +1056,24 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.wait(); - try expectEqual(.expectation_failed, req.response.status); + var body_writer = try req.sendBody(&.{}); + try body_writer.flush(); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.expectation_failed, response.head.status); + _ = try response.reader(&.{}).discardRemaining(); } - - _ = try client.fetch(.{ - .location = .{ - .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}), - }, - }); } const TestServer = struct { + shutting_down: bool, server_thread: std.Thread, net_server: std.net.Server, fn destroy(self: *@This()) void { + self.shutting_down = true; + const conn = std.net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure"); + conn.close(); + self.server_thread.join(); self.net_server.deinit(); std.testing.allocator.destroy(self); @@ -1153,20 +1093,27 @@ fn createTestServer(S: type) !*TestServer { const address = try std.net.Address.parseIp("127.0.0.1", 0); const test_server = try std.testing.allocator.create(TestServer); - test_server.net_server = try address.listen(.{ .reuse_address = true }); - test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server}); + test_server.* = .{ + .net_server = try address.listen(.{ .reuse_address = true }), + .server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}), + .shutting_down = false, + }; return test_server; } test "redirect to different connection" { const test_server_new = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [888]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [888]u8 = undefined; + var send_buffer: [777]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/ok"); try request.respond("good job, you pass", .{}); @@ -1180,18 +1127,22 @@ test "redirect to different connection" { global.other_port = test_server_new.port(); const test_server_orig = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [999]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{ + var loc_buf: [50]u8 = undefined; + const new_loc = try std.fmt.bufPrint(&loc_buf, "http://127.0.0.1:{d}/ok", .{ global.other_port.?, }); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/help"); try request.respond("", .{ @@ -1216,16 +1167,15 @@ test "redirect to different connection" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [666]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [666]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + var reader = response.reader(&.{}); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try reader.allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("good job, you pass", body);