From e2d81bf6c04d73cbe236e4b497fd2c8c54916077 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 16:38:41 -0700 Subject: [PATCH] http fixes --- lib/std/Io/Writer.zig | 94 +++++++++++--- lib/std/Uri.zig | 236 +++++++++++++--------------------- lib/std/crypto/tls/Client.zig | 4 +- lib/std/http.zig | 82 +++++++++--- lib/std/http/Client.zig | 29 +++-- lib/std/http/Server.zig | 40 +++--- lib/std/http/test.zig | 56 ++++---- 7 files changed, 297 insertions(+), 244 deletions(-) diff --git a/lib/std/Io/Writer.zig b/lib/std/Io/Writer.zig index a84077f8f3..797a69914c 100644 --- a/lib/std/Io/Writer.zig +++ b/lib/std/Io/Writer.zig @@ -191,29 +191,87 @@ pub fn writeSplatHeader( data: []const []const u8, splat: usize, ) Error!usize { - const new_end = w.end + header.len; - if (new_end <= w.buffer.len) { - @memcpy(w.buffer[w.end..][0..header.len], header); - w.end = new_end; - return header.len + try writeSplat(w, data, splat); + return writeSplatHeaderLimit(w, header, data, splat, .unlimited); +} + +/// Equivalent to `writeSplatHeader` but writes at most `limit` bytes. +pub fn writeSplatHeaderLimit( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: Limit, +) Error!usize { + var remaining = @intFromEnum(limit); + { + const copy_len = @min(header.len, w.buffer.len - w.end, remaining); + if (header.len - copy_len != 0) return writeSplatHeaderLimitFinish(w, header, data, splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], header[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; } - var vecs: [8][]const u8 = undefined; // Arbitrarily chosen size. - var i: usize = 1; - vecs[0] = header; - for (data[0 .. data.len - 1]) |buf| { - if (buf.len == 0) continue; - vecs[i] = buf; - i += 1; - if (vecs.len - i == 0) break; + for (data[0 .. data.len - 1], 0..) 
|buf, i| { + const copy_len = @min(buf.len, w.buffer.len - w.end, remaining); + if (buf.len - copy_len != 0) return @intFromEnum(limit) - remaining + + try writeSplatHeaderLimitFinish(w, &.{}, data[i..], splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], buf[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; } const pattern = data[data.len - 1]; - const new_splat = s: { - if (pattern.len == 0 or vecs.len - i == 0) break :s 1; + const splat_n = pattern.len * splat; + if (splat_n > @min(w.buffer.len - w.end, remaining)) { + const buffered_n = @intFromEnum(limit) - remaining; + const written = try writeSplatHeaderLimitFinish(w, &.{}, data[data.len - 1 ..][0..1], splat, remaining); + return buffered_n + written; + } + + for (0..splat) |_| { + @memcpy(w.buffer[w.end..][0..pattern.len], pattern); + w.end += pattern.len; + } + + remaining -= splat_n; + return @intFromEnum(limit) - remaining; +} + +fn writeSplatHeaderLimitFinish( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: usize, +) Error!usize { + var remaining = limit; + var vecs: [8][]const u8 = undefined; + var i: usize = 0; + v: { + if (header.len != 0) { + const copy_len = @min(header.len, remaining); + vecs[i] = header[0..copy_len]; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + } + for (data[0 .. 
data.len - 1]) |buf| if (buf.len != 0) { + const copy_len = @min(buf.len, remaining); + vecs[i] = buf[0..copy_len]; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + if (vecs.len - i == 0) break :v; + }; + const pattern = data[data.len - 1]; + if (splat == 1) { + vecs[i] = pattern[0..@min(remaining, pattern.len)]; + i += 1; + break :v; + } vecs[i] = pattern; i += 1; - break :s splat; - }; - return w.vtable.drain(w, vecs[0..i], new_splat); + return w.vtable.drain(w, (&vecs)[0..i], @min(remaining / pattern.len, splat)); + } + return w.vtable.drain(w, (&vecs)[0..i], 1); } test "writeSplatHeader splatting avoids buffer aliasing temptation" { diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 19af1512c2..5390bee5b5 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -4,6 +4,8 @@ const std = @import("std.zig"); const testing = std.testing; const Uri = @This(); +const Allocator = std.mem.Allocator; +const Writer = std.Io.Writer; scheme: []const u8, user: ?Component = null, @@ -14,6 +16,32 @@ path: Component = Component.empty, query: ?Component = null, fragment: ?Component = null, +pub const host_name_max = 255; + +/// Returned value may point into `buffer` or be the original string. +/// +/// Suggested buffer length: `host_name_max`. +/// +/// See also: +/// * `getHostAlloc` +pub fn getHost(uri: Uri, buffer: []u8) error{ UriMissingHost, UriHostTooLong }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + return component.toRaw(buffer) catch |err| switch (err) { + error.NoSpaceLeft => return error.UriHostTooLong, + }; +} + +/// Returned value may point into `buffer` or be the original string. 
+/// +/// See also: +/// * `getHost` +pub fn getHostAlloc(uri: Uri, arena: Allocator) error{ UriMissingHost, UriHostTooLong, OutOfMemory }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + const result = try component.toRawMaybeAlloc(arena); + if (result.len > host_name_max) return error.UriHostTooLong; + return result; +} + pub const Component = union(enum) { /// Invalid characters in this component must be percent encoded /// before being printed as part of a URI. @@ -30,11 +58,19 @@ pub const Component = union(enum) { }; } + /// Returned value may point into `buffer` or be the original string. + pub fn toRaw(component: Component, buffer: []u8) error{NoSpaceLeft}![]const u8 { + return switch (component) { + .raw => |raw| raw, + .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| + try std.fmt.bufPrint(buffer, "{f}", .{std.fmt.alt(component, .formatRaw)}) + else + percent_encoded, + }; + } + /// Allocates the result with `arena` only if needed, so the result should not be freed. 
- pub fn toRawMaybeAlloc( - component: Component, - arena: std.mem.Allocator, - ) std.mem.Allocator.Error![]const u8 { + pub fn toRawMaybeAlloc(component: Component, arena: Allocator) Allocator.Error![]const u8 { return switch (component) { .raw => |raw| raw, .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| @@ -44,7 +80,7 @@ pub const Component = union(enum) { }; } - pub fn formatRaw(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatRaw(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try w.writeAll(raw), .percent_encoded => |percent_encoded| { @@ -67,56 +103,56 @@ pub const Component = union(enum) { } } - pub fn formatEscaped(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatEscaped(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUnreserved), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatUser(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatUser(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUserChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatPassword(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatPassword(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPasswordChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatHost(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatHost(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isHostChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - 
pub fn formatPath(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatPath(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPathChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatQuery(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatQuery(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isQueryChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatFragment(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatFragment(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isFragmentChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn percentEncode(w: *std.io.Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) std.io.Writer.Error!void { + pub fn percentEncode(w: *Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) Writer.Error!void { var start: usize = 0; for (raw, 0..) |char, index| { if (isValidChar(char)) continue; @@ -165,17 +201,15 @@ pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort }; /// The return value will contain strings pointing into the original `text`. /// Each component that is provided, will be non-`null`. pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { - var reader = SliceReader{ .slice = text }; - var uri: Uri = .{ .scheme = scheme, .path = undefined }; + var i: usize = 0; - if (reader.peekPrefix("//")) a: { // authority part - std.debug.assert(reader.get().? == '/'); - std.debug.assert(reader.get().? 
== '/'); - - const authority = reader.readUntil(isAuthoritySeparator); + if (std.mem.startsWith(u8, text, "//")) a: { + i = std.mem.indexOfAnyPos(u8, text, 2, &authority_sep) orelse text.len; + const authority = text[2..i]; if (authority.len == 0) { - if (reader.peekPrefix("/")) break :a else return error.InvalidFormat; + if (!std.mem.startsWith(u8, text[2..], "/")) return error.InvalidFormat; + break :a; } var start_of_host: usize = 0; @@ -225,26 +259,28 @@ pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { uri.host = .{ .percent_encoded = authority[start_of_host..end_of_host] }; } - uri.path = .{ .percent_encoded = reader.readUntil(isPathSeparator) }; + const path_start = i; + i = std.mem.indexOfAnyPos(u8, text, path_start, &path_sep) orelse text.len; + uri.path = .{ .percent_encoded = text[path_start..i] }; - if ((reader.peek() orelse 0) == '?') { // query part - std.debug.assert(reader.get().? == '?'); - uri.query = .{ .percent_encoded = reader.readUntil(isQuerySeparator) }; + if (std.mem.startsWith(u8, text[i..], "?")) { + const query_start = i + 1; + i = std.mem.indexOfScalarPos(u8, text, query_start, '#') orelse text.len; + uri.query = .{ .percent_encoded = text[query_start..i] }; } - if ((reader.peek() orelse 0) == '#') { // fragment part - std.debug.assert(reader.get().? == '#'); - uri.fragment = .{ .percent_encoded = reader.readUntilEof() }; + if (std.mem.startsWith(u8, text[i..], "#")) { + uri.fragment = .{ .percent_encoded = text[i + 1 ..] 
}; } return uri; } -pub fn format(uri: *const Uri, writer: *std.io.Writer) std.io.Writer.Error!void { +pub fn format(uri: *const Uri, writer: *Writer) Writer.Error!void { return writeToStream(uri, writer, .all); } -pub fn writeToStream(uri: *const Uri, writer: *std.io.Writer, flags: Format.Flags) std.io.Writer.Error!void { +pub fn writeToStream(uri: *const Uri, writer: *Writer, flags: Format.Flags) Writer.Error!void { if (flags.scheme) { try writer.print("{s}:", .{uri.scheme}); if (flags.authority and uri.host != null) { @@ -318,7 +354,7 @@ pub const Format = struct { }; }; - pub fn default(f: Format, writer: *std.io.Writer) std.io.Writer.Error!void { + pub fn default(f: Format, writer: *Writer) Writer.Error!void { return writeToStream(f.uri, writer, f.flags); } }; @@ -327,41 +363,33 @@ pub fn fmt(uri: *const Uri, flags: Format.Flags) std.fmt.Formatter(Format, Forma return .{ .data = .{ .uri = uri, .flags = flags } }; } -/// Parses the URI or returns an error. -/// The return value will contain strings pointing into the -/// original `text`. Each component that is provided, will be non-`null`. +/// The return value will contain strings pointing into the original `text`. +/// Each component that is provided will be non-`null`. pub fn parse(text: []const u8) ParseError!Uri { - var reader: SliceReader = .{ .slice = text }; - const scheme = reader.readWhile(isSchemeChar); - - // after the scheme, a ':' must appear - if (reader.get()) |c| { - if (c != ':') - return error.UnexpectedCharacter; - } else { - return error.InvalidFormat; - } - - return parseAfterScheme(scheme, reader.readUntilEof()); + const end = for (text, 0..) |byte, i| { + if (!isSchemeChar(byte)) break i; + } else text.len; + // After the scheme, a ':' must appear. 
+ if (end >= text.len) return error.InvalidFormat; + if (text[end] != ':') return error.UnexpectedCharacter; + return parseAfterScheme(text[0..end], text[end + 1 ..]); } pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft}; /// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. -/// Copies `new` to the beginning of `aux_buf.*`, allowing the slices to overlap, -/// then parses `new` as a URI, and then resolves the path in place. +/// +/// Assumes new location is already copied to the beginning of `aux_buf.*`. +/// Parses that new location as a URI, and then resolves the path in place. +/// /// If a merge needs to take place, the newly constructed path will be stored -/// in `aux_buf.*` just after the copied `new`, and `aux_buf.*` will be modified -/// to only contain the remaining unused space. -pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: *[]u8) ResolveInPlaceError!Uri { - std.mem.copyForwards(u8, aux_buf.*, new); - // At this point, new is an invalid pointer. - const new_mut = aux_buf.*[0..new.len]; - aux_buf.* = aux_buf.*[new.len..]; - - const new_parsed = parse(new_mut) catch |err| - (parseAfterScheme("", new_mut) catch return err); - // As you can see above, `new_mut` is not a const pointer. +/// in `aux_buf.*` just after the copied location, and `aux_buf.*` will be +/// modified to only contain the remaining unused space. +pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceError!Uri { + const new = aux_buf.*[0..new_len]; + const new_parsed = parse(new) catch |err| (parseAfterScheme("", new) catch return err); + aux_buf.* = aux_buf.*[new_len..]; + // As you can see above, `new` is not a const pointer. const new_path: []u8 = @constCast(new_parsed.path.percent_encoded); if (new_parsed.scheme.len > 0) return .{ @@ -461,7 +489,7 @@ test remove_dot_segments { /// 5.2.3. 
Merge Paths fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Component { - var aux: std.io.Writer = .fixed(aux_buf.*); + var aux: Writer = .fixed(aux_buf.*); if (!base.isEmpty()) { base.formatPath(&aux) catch return error.NoSpaceLeft; aux.end = std.mem.lastIndexOfScalar(u8, aux.buffered(), '/') orelse return remove_dot_segments(new); @@ -472,59 +500,6 @@ fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Co return merged_path; } -const SliceReader = struct { - const Self = @This(); - - slice: []const u8, - offset: usize = 0, - - fn get(self: *Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - const c = self.slice[self.offset]; - self.offset += 1; - return c; - } - - fn peek(self: Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - return self.slice[self.offset]; - } - - fn readWhile(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntil(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and !predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntilEof(self: *Self) []const u8 { - const start = self.offset; - self.offset = self.slice.len; - return self.slice[start..]; - } - - fn peekPrefix(self: Self, prefix: []const u8) bool { - if (self.offset + prefix.len > self.slice.len) - return false; - return std.mem.eql(u8, self.slice[self.offset..][0..prefix.len], prefix); - } -}; - /// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." 
) fn isSchemeChar(c: u8) bool { return switch (c) { @@ -533,19 +508,6 @@ fn isSchemeChar(c: u8) bool { }; } -/// reserved = gen-delims / sub-delims -fn isReserved(c: u8) bool { - return isGenLimit(c) or isSubLimit(c); -} - -/// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" -fn isGenLimit(c: u8) bool { - return switch (c) { - ':', ',', '?', '#', '[', ']', '@' => true, - else => false, - }; -} - /// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" /// / "*" / "+" / "," / ";" / "=" fn isSubLimit(c: u8) bool { @@ -585,26 +547,8 @@ fn isQueryChar(c: u8) bool { const isFragmentChar = isQueryChar; -fn isAuthoritySeparator(c: u8) bool { - return switch (c) { - '/', '?', '#' => true, - else => false, - }; -} - -fn isPathSeparator(c: u8) bool { - return switch (c) { - '?', '#' => true, - else => false, - }; -} - -fn isQuerySeparator(c: u8) bool { - return switch (c) { - '#' => true, - else => false, - }; -} +const authority_sep: [3]u8 = .{ '/', '?', '#' }; +const path_sep: [2]u8 = .{ '?', '#' }; test "basic" { const parsed = try parse("https://ziglang.org/download"); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 082fc9da70..5a9e333b67 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -328,7 +328,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; fragment: while (true) { // Ensure the input buffer pointer is stable in this scope. - input.rebaseCapacity(tls.max_ciphertext_record_len); + input.rebase(tls.max_ciphertext_record_len) catch |err| switch (err) { + error.EndOfStream => {}, // We have assurance the remainder of stream can be buffered. 
+ }; const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { error.EndOfStream => return error.TlsConnectionTruncated, error.ReadFailed => return error.ReadFailed, diff --git a/lib/std/http.zig b/lib/std/http.zig index c64a946a25..24bc016c30 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -343,9 +343,6 @@ pub const Reader = struct { /// read from `in`. trailers: []const u8 = &.{}, body_err: ?BodyError = null, - /// Determines at which point `error.HttpHeadersOversize` occurs, as well - /// as the minimum buffer capacity of `in`. - max_head_len: usize, pub const RemainingChunkLen = enum(u64) { head = 0, @@ -397,27 +394,34 @@ pub const Reader = struct { ReadFailed, }; - /// Buffers the entire head. - pub fn receiveHead(reader: *Reader) HeadError!void { + /// Buffers the entire head inside `in`. + /// + /// The resulting memory is invalidated by any subsequent consumption of + /// the input stream. + pub fn receiveHead(reader: *Reader) HeadError![]const u8 { reader.trailers = &.{}; const in = reader.in; - try in.rebase(reader.max_head_len); var hp: HeadParser = .{}; - var head_end: usize = 0; + var head_len: usize = 0; while (true) { - if (head_end >= in.buffer.len) return error.HttpHeadersOversize; - in.fillMore() catch |err| switch (err) { - error.EndOfStream => switch (head_end) { - 0 => return error.HttpConnectionClosing, - else => return error.HttpRequestTruncated, - }, - error.ReadFailed => return error.ReadFailed, - }; - head_end += hp.feed(in.buffered()[head_end..]); + if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize; + const remaining = in.buffered()[head_len..]; + if (remaining.len == 0) { + in.fillMore() catch |err| switch (err) { + error.EndOfStream => switch (head_len) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }, + error.ReadFailed => return error.ReadFailed, + }; + continue; + } + head_len += hp.feed(remaining); if (hp.state == .finished) { - 
reader.head_buffer = in.steal(head_end); reader.state = .received_head; - return; + const head_buffer = in.buffered()[0..head_len]; + in.toss(head_len); + return head_buffer; } } } @@ -786,7 +790,7 @@ pub const BodyWriter = struct { }; pub fn isEliding(w: *const BodyWriter) bool { - return w.writer.vtable.drain == Writer.discardingDrain; + return w.writer.vtable.drain == elidingDrain; } /// Sends all buffered data across `BodyWriter.http_protocol_output`. @@ -930,6 +934,46 @@ pub const BodyWriter = struct { return w.consume(n); } + pub fn elidingDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + const slice = data[0 .. data.len - 1]; + const pattern = data[slice.len]; + var written: usize = pattern.len * splat; + for (slice) |bytes| written += bytes.len; + switch (bw.state) { + .content_length => |*len| len.* -= written + w.end, + else => {}, + } + w.end = 0; + return written; + } + + pub fn elidingSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + if (File.Handle == void) return error.Unimplemented; + if (builtin.zig_backend == .stage2_aarch64) return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= w.end, + else => {}, + } + w.end = 0; + if (limit == .nothing) return 0; + if (file_reader.getSize()) |size| { + const n = limit.minInt64(size - file_reader.pos); + if (n == 0) return error.EndOfStream; + file_reader.seekBy(@intCast(n)) catch return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= n, + else => {}, + } + return n; + } else |_| { + // Error is observable on `file_reader` instance, and it is better to + // treat the file as a pipe. + return error.Unimplemented; + } + } + /// Returns `null` if size cannot be computed without making any syscalls. 
pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { const bw: *BodyWriter = @fieldParentPtr("writer", w); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 61a9eeb5c3..6660761930 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -821,7 +821,6 @@ pub const Request = struct { /// Returns the request's `Connection` back to the pool of the `Client`. pub fn deinit(r: *Request) void { - r.reader.restituteHeadBuffer(); if (r.connection) |connection| { connection.closing = connection.closing or switch (r.reader.state) { .ready => false, @@ -908,13 +907,13 @@ pub const Request = struct { const connection = r.connection.?; const w = connection.writer(); - try r.method.write(w); + try r.method.format(w); try w.writeByte(' '); if (r.method == .CONNECT) { - try uri.writeToStream(.{ .authority = true }, w); + try uri.writeToStream(w, .{ .authority = true }); } else { - try uri.writeToStream(.{ + try uri.writeToStream(w, .{ .scheme = connection.proxied, .authentication = connection.proxied, .authority = connection.proxied, @@ -928,7 +927,7 @@ pub const Request = struct { if (try emitOverridableHeader("host: ", r.headers.host, w)) { try w.writeAll("host: "); - try uri.writeToStream(.{ .authority = true }, w); + try uri.writeToStream(w, .{ .authority = true }); try w.writeAll("\r\n"); } @@ -1046,10 +1045,10 @@ pub const Request = struct { pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { var aux_buf = redirect_buffer; while (true) { - try r.reader.receiveHead(); + const head_buffer = try r.reader.receiveHead(); const response: Response = .{ .request = r, - .head = Response.Head.parse(r.reader.head_buffer) catch return error.HttpHeadersInvalid, + .head = Response.Head.parse(head_buffer) catch return error.HttpHeadersInvalid, }; const head = &response.head; @@ -1121,7 +1120,6 @@ pub const Request = struct { _ = reader.discardRemaining() catch |err| switch (err) 
{ error.ReadFailed => return r.reader.body_err.?, }; - r.reader.restituteHeadBuffer(); } const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) { error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid, @@ -1302,12 +1300,13 @@ pub const basic_authorization = struct { } pub fn write(uri: Uri, out: *Writer) Writer.Error!void { - var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; + var buf: [max_user_len + 1 + max_password_len]u8 = undefined; var w: Writer = .fixed(&buf); - w.print("{fuser}:{fpassword}", .{ - uri.user orelse Uri.Component.empty, - uri.password orelse Uri.Component.empty, - }) catch unreachable; + const user: Uri.Component = uri.user orelse .empty; + const password: Uri.Component = uri.password orelse .empty; + user.formatUser(&w) catch unreachable; + w.writeByte(':') catch unreachable; + password.formatPassword(&w) catch unreachable; try out.print("Basic {b64}", .{w.buffered()}); } }; @@ -1697,6 +1696,7 @@ pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadErro StreamTooLong, /// TODO provide optional diagnostics when this occurs or break into more error codes WriteFailed, + UnsupportedCompressionMethod, }; /// Perform a one-shot HTTP request with the provided options. 
@@ -1748,7 +1748,8 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult { const decompress_buffer: []u8 = switch (response.head.content_encoding) { .identity => &.{}, .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len), - else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024), + .deflate, .gzip => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.flate.max_window_len), + .compress => return error.UnsupportedCompressionMethod, }; defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 004741d1ae..aa6c72a5b3 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -6,7 +6,7 @@ const mem = std.mem; const Uri = std.Uri; const assert = std.debug.assert; const testing = std.testing; -const Writer = std.io.Writer; +const Writer = std.Io.Writer; const Server = @This(); @@ -21,7 +21,7 @@ reader: http.Reader, /// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. /// /// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(in: *std.io.Reader, out: *Writer) Server { +pub fn init(in: *std.Io.Reader, out: *Writer) Server { return .{ .reader = .{ .in = in, @@ -33,25 +33,22 @@ pub fn init(in: *std.io.Reader, out: *Writer) Server { }; } -pub fn deinit(s: *Server) void { - s.reader.restituteHeadBuffer(); -} - pub const ReceiveHeadError = http.Reader.HeadError || error{ /// Client sent headers that did not conform to the HTTP protocol. /// - /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be + /// To find out more detailed diagnostics, `Request.head_buffer` can be /// passed directly to `Request.Head.parse`. 
HttpHeadersInvalid, }; pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - try s.reader.receiveHead(); + const head_buffer = try s.reader.receiveHead(); return .{ .server = s, + .head_buffer = head_buffer, // No need to track the returned error here since users can repeat the // parse with the header buffer to get detailed diagnostics. - .head = Request.Head.parse(s.reader.head_buffer) catch return error.HttpHeadersInvalid, + .head = Request.Head.parse(head_buffer) catch return error.HttpHeadersInvalid, }; } @@ -60,6 +57,7 @@ pub const Request = struct { /// Pointers in this struct are invalidated with the next call to /// `receiveHead`. head: Head, + head_buffer: []const u8, respond_err: ?RespondError = null, pub const RespondError = error{ @@ -229,7 +227,7 @@ pub const Request = struct { pub fn iterateHeaders(r: *Request) http.HeaderIterator { assert(r.server.reader.state == .received_head); - return http.HeaderIterator.init(r.server.reader.head_buffer); + return http.HeaderIterator.init(r.head_buffer); } test iterateHeaders { @@ -244,7 +242,6 @@ pub const Request = struct { .reader = .{ .in = undefined, .state = .received_head, - .head_buffer = @constCast(request_bytes), .interface = undefined, }, .out = undefined, @@ -253,6 +250,7 @@ pub const Request = struct { var request: Request = .{ .server = &server, .head = undefined, + .head_buffer = @constCast(request_bytes), }; var it = request.iterateHeaders(); @@ -435,10 +433,8 @@ pub const Request = struct { for (o.extra_headers) |header| { assert(header.name.len != 0); - try out.writeAll(header.name); - try out.writeAll(": "); - try out.writeAll(header.value); - try out.writeAll("\r\n"); + var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&bufs); } try out.writeAll("\r\n"); @@ -453,7 +449,13 @@ pub const Request = struct { return if (elide_body) .{ .http_protocol_output = request.server.out, .state = state, - .writer = .discarding(buffer), + .writer = .{ + 
.buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.elidingDrain, + .sendFile = http.BodyWriter.elidingSendFile, + }, + }, } else .{ .http_protocol_output = request.server.out, .state = state, @@ -564,7 +566,7 @@ pub const Request = struct { /// /// See `readerExpectNone` for an infallible alternative that cannot write /// to the server output stream. - pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.io.Reader { + pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.Io.Reader { const flush = request.head.expect != null; try writeExpectContinue(request); if (flush) try request.server.out.flush(); @@ -576,7 +578,7 @@ pub const Request = struct { /// this function. /// /// Asserts that this function is only called once. - pub fn readerExpectNone(request: *Request, buffer: []u8) *std.io.Reader { + pub fn readerExpectNone(request: *Request, buffer: []u8) *std.Io.Reader { assert(request.server.reader.state == .received_head); assert(request.head.expect == null); if (!request.head.method.requestHasBody()) return .ending; @@ -640,7 +642,7 @@ pub const Request = struct { /// See https://tools.ietf.org/html/rfc6455 pub const WebSocket = struct { key: []const u8, - input: *std.io.Reader, + input: *std.Io.Reader, output: *Writer, pub const Header0 = packed struct(u8) { diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 4c3466d5c9..df80ca6339 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -65,7 +65,7 @@ test "trailers" { try req.sendBodiless(); var response = try req.receiveHead(&.{}); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -183,7 +183,11 @@ test "echo content server" { if (request.head.expect) |expect_header_value| { if (mem.eql(u8, expect_header_value, "garbage")) { try 
expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{})); - try request.respond("", .{ .keep_alive = false }); + request.head.expect = null; + try request.respond("", .{ + .keep_alive = false, + .status = .expectation_failed, + }); continue; } } @@ -204,7 +208,7 @@ test "echo content server" { // request.head.target, //}); - const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .limited(8192)); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited); defer std.testing.allocator.free(body); try expect(mem.startsWith(u8, request.head.target, "/echo-content")); @@ -273,7 +277,6 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { for (0..500) |i| { try w.print("{d}, ah ha ha!\n", .{i}); } - try expectEqual(7390, w.count); try w.flush(); try response.end(); try expectEqual(.closing, server.reader.state); @@ -291,7 +294,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -362,7 +365,7 @@ test "receiving arbitrary http headers from the client" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -408,12 +411,10 @@ test "general client/server API coverage" { fn handleRequest(request: 
*http.Server.Request, listen_port: u16) !void { const log = std.log.scoped(.server); - log.info("{f} {s} {s}", .{ - request.head.method, @tagName(request.head.version), request.head.target, - }); + log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); const gpa = std.testing.allocator; - const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .limited(8192)); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .unlimited); defer gpa.free(body); if (mem.startsWith(u8, request.head.target, "/get")) { @@ -447,7 +448,8 @@ test "general client/server API coverage" { try w.writeAll("Hello, World!\n"); } - try w.writeAll("Hello, World!\n" ** 1024); + var vec: [1][]const u8 = .{"Hello, World!\n"}; + try w.writeSplatAll(&vec, 1024); i = 0; while (i < 5) : (i += 1) { @@ -556,7 +558,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -579,7 +581,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192 * 1024)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); @@ -601,7 +603,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -625,7 +627,7 @@ test "general 
client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -648,7 +650,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -674,7 +676,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -703,7 +705,7 @@ test "general client/server API coverage" { try std.testing.expectEqual(.ok, response.head.status); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -740,7 +742,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -762,7 +764,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try 
response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -784,7 +786,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -825,7 +827,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Encoded redirect successful!\n", body); @@ -915,7 +917,7 @@ test "Server streams both reading and writing" { try body_writer.writer.writeAll("fish"); try body_writer.end(); - const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .unlimited); defer std.testing.allocator.free(body); try expectEqualStrings("ONE FISH", body); @@ -947,7 +949,7 @@ fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -980,7 +982,7 @@ fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1034,7 +1036,7 @@ 
fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); try expectEqual(.ok, response.head.status); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1175,7 +1177,7 @@ test "redirect to different connection" { var response = try req.receiveHead(&redirect_buffer); var reader = response.reader(&.{}); - const body = try reader.allocRemaining(gpa, .limited(8192)); + const body = try reader.allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("good job, you pass", body);