From 2cdc0a8b504a501d3eb5184f096720e6a5dc351a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 3 Jan 2023 16:03:28 -0700 Subject: [PATCH 1/7] std.http.Client: do not heap allocate for requests --- lib/std/http/Client.zig | 105 +++++++++++++++------------------------- 1 file changed, 38 insertions(+), 67 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 8a4a771416..b6be5cee10 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -9,9 +9,8 @@ const net = std.net; const Client = @This(); const Url = std.Url; +/// TODO: remove this field (currently required due to tcpConnectToHost) allocator: std.mem.Allocator, -headers: std.ArrayListUnmanaged(u8) = .{}, -active_requests: usize = 0, ca_bundle: std.crypto.Certificate.Bundle = .{}, /// TODO: emit error.UnexpectedEndOfStream or something like that when the read @@ -20,44 +19,23 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, pub const Request = struct { client: *Client, stream: net.Stream, - headers: std.ArrayListUnmanaged(u8) = .{}, tls_client: std.crypto.tls.Client, protocol: Protocol, response_headers: http.Headers = .{}, - pub const Protocol = enum { http, https }; - - pub const Options = struct { + pub const Headers = struct { method: http.Method = .GET, + connection: Connection, + + pub const Connection = enum { + close, + @"keep-alive", + }; }; - pub fn deinit(req: *Request) void { - req.client.active_requests -= 1; - req.headers.deinit(req.client.allocator); - req.* = undefined; - } + pub const Protocol = enum { http, https }; - pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void { - const gpa = req.client.allocator; - // Ensure an extra +2 for the \r\n in end() - try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6); - req.headers.appendSliceAssumeCapacity(name); - req.headers.appendSliceAssumeCapacity(": "); - req.headers.appendSliceAssumeCapacity(value); - req.headers.appendSliceAssumeCapacity("\r\n"); - } - - pub fn end(req: *Request) !void { - req.headers.appendSliceAssumeCapacity("\r\n"); - switch (req.protocol) { - .http => { - try req.stream.writeAll(req.headers.items); - }, - .https => { - try req.tls_client.writeAll(req.stream, req.headers.items); - }, - } - } + pub const Options = struct {}; pub fn readAll(req: *Request, buffer: []u8) !usize { return readAtLeast(req, buffer, buffer.len); @@ -113,13 +91,14 @@ pub const Request = struct { } }; -pub fn deinit(client: *Client) void { - assert(client.active_requests == 0); - client.headers.deinit(client.allocator); +pub fn deinit(client: *Client, gpa: std.mem.Allocator) void { + client.ca_bundle.deinit(gpa); client.* = undefined; } -pub fn request(client: *Client, url: Url, options: Request.Options) !Request { +pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request { + _ = options; // we have no options yet + const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = url.port orelse switch (protocol) { @@ -133,8 +112,6 @@ pub fn request(client: *Client, url: Url, options: Request.Options) !Request { .protocol = protocol, .tls_client = undefined, }; - client.active_requests += 1; - errdefer req.deinit(); switch (protocol) { .http => {}, @@ -146,36 +123,30 @@ pub fn request(client: *Client, url: Url, options: Request.Options) !Request { }, } - try req.headers.ensureUnusedCapacity( - client.allocator, - @tagName(options.method).len + - 1 + - url.path.len + - " 
HTTP/1.1\r\nHost: ".len + - url.host.len + - "\r\nUpgrade-Insecure-Requests: 1\r\n".len + - client.headers.items.len + - 2, // for the \r\n at the end of headers - ); - req.headers.appendSliceAssumeCapacity(@tagName(options.method)); - req.headers.appendSliceAssumeCapacity(" "); - req.headers.appendSliceAssumeCapacity(url.path); - req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: "); - req.headers.appendSliceAssumeCapacity(url.host); - switch (protocol) { - .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), - .http => req.headers.appendSliceAssumeCapacity("\r\n"), + { + var h = try std.BoundedArray(u8, 1000).init(0); + try h.appendSlice(@tagName(headers.method)); + try h.appendSlice(" "); + try h.appendSlice(url.path); + try h.appendSlice(" HTTP/1.1\r\nHost: "); + try h.appendSlice(url.host); + switch (protocol) { + .https => try h.appendSlice("\r\nUpgrade-Insecure-Requests: 1\r\n"), + .http => try h.appendSlice("\r\n"), + } + try h.writer().print("Connection: {s}\r\n", .{@tagName(headers.connection)}); + try h.appendSlice("\r\n"); + + const header_bytes = h.slice(); + switch (req.protocol) { + .http => { + try req.stream.writeAll(header_bytes); + }, + .https => { + try req.tls_client.writeAll(req.stream, header_bytes); + }, + } } - req.headers.appendSliceAssumeCapacity(client.headers.items); return req; } - -pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void { - const gpa = client.allocator; - try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4); - client.headers.appendSliceAssumeCapacity(name); - client.headers.appendSliceAssumeCapacity(": "); - client.headers.appendSliceAssumeCapacity(value); - client.headers.appendSliceAssumeCapacity("\r\n"); -} From 5d9429579dbab98d5d565e291dfe89f9b8261526 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 3 Jan 2023 19:14:44 -0700 Subject: [PATCH 2/7] std.http.Headers.Parser: parse version and status --- lib/std/http.zig | 55 ++--------- lib/std/http/Client.zig | 10 +- lib/std/http/Headers.zig | 193 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 53 deletions(-) create mode 100644 lib/std/http/Headers.zig diff --git a/lib/std/http.zig b/lib/std/http.zig index 944271df27..d6d21d2c9e 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,4 +1,10 @@ pub const Client = @import("http/Client.zig"); +pub const Headers = @import("http/Headers.zig"); + +pub const Version = enum { + @"HTTP/1.0", + @"HTTP/1.1", +}; /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton @@ -242,55 +248,6 @@ pub const Status = enum(u10) { } }; -pub const Headers = struct { - state: State = .start, - invalid_index: u32 = undefined, - - pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished }; - - /// Returns how many bytes are processed into headers. Always less than or - /// equal to bytes.len. If the amount returned is less than bytes.len, it - /// means the headers ended and the first byte after the double \r\n\r\n is - /// located at `bytes[result]`. 
- pub fn feed(h: *Headers, bytes: []const u8) usize { - for (bytes) |b, i| { - switch (h.state) { - .start => switch (b) { - '\r' => h.state = .nl_r, - '\n' => return invalid(h, i), - else => {}, - }, - .nl_r => switch (b) { - '\n' => h.state = .nl_n, - else => return invalid(h, i), - }, - .nl_n => switch (b) { - '\r' => h.state = .nl2_r, - else => h.state = .line, - }, - .nl2_r => switch (b) { - '\n' => h.state = .finished, - else => return invalid(h, i), - }, - .line => switch (b) { - '\r' => h.state = .nl_r, - '\n' => return invalid(h, i), - else => {}, - }, - .invalid => return i, - .finished => return i, - } - } - return bytes.len; - } - - fn invalid(h: *Headers, i: usize) usize { - h.invalid_index = @intCast(u32, i); - h.state = .invalid; - return i; - } -}; - const std = @import("std.zig"); test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index b6be5cee10..0b5f06669a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,7 +21,8 @@ pub const Request = struct { stream: net.Stream, tls_client: std.crypto.tls.Client, protocol: Protocol, - response_headers: http.Headers = .{}, + response_headers: http.Headers, + redirects_left: u32, pub const Headers = struct { method: http.Method = .GET, @@ -35,7 +36,9 @@ pub const Request = struct { pub const Protocol = enum { http, https }; - pub const Options = struct {}; + pub const Options = struct { + max_redirects: u32 = 3, + }; pub fn readAll(req: *Request, buffer: []u8) !usize { return readAtLeast(req, buffer, buffer.len); @@ -97,8 +100,6 @@ pub fn deinit(client: *Client, gpa: std.mem.Allocator) void { } pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request { - _ = options; // we have no options yet - const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = url.port orelse switch (protocol) { @@ -111,6 +112,7 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req .stream = try net.tcpConnectToHost(client.allocator, url.host, port), .protocol = protocol, .tls_client = undefined, + .redirects_left = options.max_redirects, }; switch (protocol) { diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig new file mode 100644 index 0000000000..63666e1a53 --- /dev/null +++ b/lib/std/http/Headers.zig @@ -0,0 +1,193 @@ +status: http.Status, +version: http.Version, + +pub const Parser = struct { + state: State, + headers: Headers, + buffer: [16]u8, + buffer_index: u4, + + pub const init: Parser = .{ + .state = .start, + .headers = .{ + .status = undefined, + .version = undefined, + }, + .buffer = undefined, + .buffer_index = 0, + }; + + pub const State = enum { + invalid, + finished, + start, + expect_status, + find_start_line_end, + line, + line_r, + }; + + /// Returns how many bytes are processed into headers. Always less than or + /// equal to bytes.len. If the amount returned is less than bytes.len, it + /// means the headers ended and the first byte after the double \r\n\r\n is + /// located at `bytes[result]`. 
+ pub fn feed(p: *Parser, bytes: []const u8) usize { + var index: usize = 0; + + while (bytes.len - index >= 16) { + index += p.feed16(bytes[index..][0..16]); + switch (p.state) { + .invalid, .finished => return index, + else => continue, + } + } + + while (index < bytes.len) { + var buffer = [1]u8{0} ** 16; + const src = bytes[index..bytes.len]; + std.mem.copy(u8, &buffer, src); + index += p.feed16(&buffer); + switch (p.state) { + .invalid, .finished => return index, + else => continue, + } + } + + return index; + } + + pub fn feed16(p: *Parser, chunk: *const [16]u8) u8 { + switch (p.state) { + .invalid, .finished => return 0, + .start => { + p.headers.version = switch (std.mem.readIntNative(u64, chunk[0..8])) { + std.mem.readIntNative(u64, "HTTP/1.0") => .@"HTTP/1.0", + std.mem.readIntNative(u64, "HTTP/1.1") => .@"HTTP/1.1", + else => return invalid(p, 0), + }; + p.state = .expect_status; + return 8; + }, + .expect_status => { + // example: " 200 OK\r\n" + // example; " 301 Moved Permanently\r\n" + switch (std.mem.readIntNative(u64, chunk[0..8])) { + std.mem.readIntNative(u64, " 200 OK\r") => { + if (chunk[8] != '\n') return invalid(p, 8); + p.headers.status = .ok; + p.state = .line; + return 9; + }, + std.mem.readIntNative(u64, " 301 Mov") => { + p.headers.status = .moved_permanently; + if (!std.mem.eql(u8, chunk[9..], "ed Perma")) + return invalid(p, 9); + p.state = .find_start_line_end; + return 16; + }, + else => { + if (chunk[0] != ' ') return invalid(p, 0); + const status = std.fmt.parseInt(u10, chunk[1..][0..3], 10) catch + return invalid(p, 1); + p.headers.status = @intToEnum(http.Status, status); + const v: @Vector(12, u8) = chunk[4..16].*; + const matches_r = v == @splat(12, @as(u8, '\r')); + const iota = std.simd.iota(u8, 12); + const default = @splat(12, @as(u8, 12)); + const index = 4 + @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + p.state = .find_start_line_end; + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + p.state = .line; + return index + 2; + }, + } + }, + .find_start_line_end => { + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const index = @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + p.state = .find_start_line_end; + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + p.state = .line; + return index + 2; + }, + .line => { + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const index = @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + if (index + 4 <= 16 and chunk[index + 2] == '\r') { + if (chunk[index + 3] != '\n') return invalid(p, index + 3); + p.state = .finished; + return index + 4; + } + p.state = .line_r; + return index + 2; + }, + .line_r => { + if (chunk[0] == '\r') { + if (chunk[1] != '\n') return invalid(p, 1); + p.state = .finished; + return 2; + } + p.state = .line; + // Here would be nice to use this proposal when it is implemented: + // https://github.com/ziglang/zig/issues/8220 + return 0; + }, + } + } + + fn invalid(p: *Parser, i: u8) u8 { + p.state = .invalid; + return i; + } +}; + +const std = @import("../std.zig"); +const http = std.http; +const Headers = @This(); 
+const testing = std.testing; + +test "status line ok" { + var p = Parser.init; + const line = "HTTP/1.1 200 OK\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.1"); + try testing.expect(p.headers.status == .ok); +} + +test "status line non hot path long msg" { + var p = Parser.init; + const line = "HTTP/1.0 418 I'm a teapot\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.0"); + try testing.expect(p.headers.status == .teapot); +} + +test "status line non hot path short msg" { + var p = Parser.init; + const line = "HTTP/1.1 418 lol\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.1"); + try testing.expect(p.headers.status == .teapot); +} From 079f62881ee8291e1b290848c1081c822ea8389f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 4 Jan 2023 18:27:52 -0700 Subject: [PATCH 3/7] std.simd.iota: make it always called at comptime There's no reason for this to ever run at runtime; it should always be used to generate a constant. --- lib/std/simd.zig | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/lib/std/simd.zig b/lib/std/simd.zig index 95de3cc11c..868f9864e7 100644 --- a/lib/std/simd.zig +++ b/lib/std/simd.zig @@ -86,16 +86,18 @@ pub fn VectorCount(comptime VectorType: type) type { /// Returns a vector containing the first `len` integers in order from 0 to `len`-1. /// For example, `iota(i32, 8)` will return a vector containing `.{0, 1, 2, 3, 4, 5, 6, 7}`. -pub fn iota(comptime T: type, comptime len: usize) @Vector(len, T) { - var out: [len]T = undefined; - for (out) |*element, i| { - element.* = switch (@typeInfo(T)) { - .Int => @intCast(T, i), - .Float => @intToFloat(T, i), - else => @compileError("Can't use type " ++ @typeName(T) ++ " in iota."), - }; +pub inline fn iota(comptime T: type, comptime len: usize) @Vector(len, T) { + comptime { + var out: [len]T = undefined; + for (out) |*element, i| { + element.* = switch (@typeInfo(T)) { + .Int => @intCast(T, i), + .Float => @intToFloat(T, i), + else => @compileError("Can't use type " ++ @typeName(T) ++ " in iota."), + }; + } + return @as(@Vector(len, T), out); } - return @as(@Vector(len, T), out); } /// Returns a vector containing the same elements as the input, but repeated until the desired length is reached. From 8248fdbbdb88cc861c3d02a26ec4a214df3f9a1e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 4 Jan 2023 18:28:37 -0700 Subject: [PATCH 4/7] std.http.Client: support HTTP redirects * std.http.Status.Class: add a "nonstandard" enum tag. Instead of having `class` return an optional value, it can potentially return nonstandard. * extract out std.http.Client.Connection from std.http.Client.Request - this code abstracts over plain/TLS only - this is the type that will potentially be stored in a client's LRU connection map * introduce two-staged HTTP header parsing - API users can rely on a heap-allocated buffer with a maximum limit, which defaults to 16 KB, or they can provide a static buffer that is borrowed by the Request instance. - The entire HTTP header is buffered because there are strings in there and they must be accessed later, such as with the case of HTTP redirects. - When buffering the HTTP header, the parser only looks for the \r\n\r\n pattern. 
Further validation is done later. - After the full HTTP header is buffered, it is parsed into components such as Content-Length and Location. * HTTP redirects are handled, with a maximum redirect count option that defaults to 3. - Connection: close is always used for now; implementing keep-alive connections and an LRU connection pool in std.http.Client is a task for another day. see #2007 --- lib/std/http.zig | 7 +- lib/std/http/Client.zig | 537 ++++++++++++++++++++++++++++++++++----- lib/std/http/Headers.zig | 193 -------------- 3 files changed, 472 insertions(+), 265 deletions(-) delete mode 100644 lib/std/http/Headers.zig diff --git a/lib/std/http.zig b/lib/std/http.zig index d6d21d2c9e..baeb7679f1 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,5 +1,4 @@ pub const Client = @import("http/Client.zig"); -pub const Headers = @import("http/Headers.zig"); pub const Version = enum { @"HTTP/1.0", @@ -219,6 +218,7 @@ pub const Status = enum(u10) { } pub const Class = enum { + nonstandard, informational, success, redirect, @@ -226,14 +226,14 @@ pub const Status = enum(u10) { server_error, }; - pub fn class(self: Status) ?Class { + pub fn class(self: Status) Class { return switch (@enumToInt(self)) { 100...199 => .informational, 200...299 => .success, 300...399 => .redirect, 400...499 => .client_error, 500...599 => .server_error, - else => null, + else => .nonstandard, }; } @@ -254,5 +254,4 @@ test { _ = Client; _ = Method; _ = Status; - _ = Headers; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 0b5f06669a..72e9321a7c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,45 +1,419 @@ //! This API is a barely-touched, barely-functional http client, just the //! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear //! with me and I promise the API will become useful and streamlined. +//! +//! TODO: send connection: keep-alive and LRU cache a configurable number of +//! open connections to skip DNS and TLS handshake for subsequent requests. const std = @import("../std.zig"); +const mem = std.mem; const assert = std.debug.assert; const http = std.http; const net = std.net; const Client = @This(); const Url = std.Url; +const Allocator = std.mem.Allocator; +const testing = std.testing; -/// TODO: remove this field (currently required due to tcpConnectToHost) -allocator: std.mem.Allocator, +/// Used for tcpConnectToHost and storing HTTP headers when an externally +/// managed buffer is not provided. +allocator: Allocator, ca_bundle: std.crypto.Certificate.Bundle = .{}, +pub const Connection = struct { + stream: net.Stream, + /// undefined unless protocol is tls. 
+ tls_client: std.crypto.tls.Client, + protocol: Protocol, + + pub const Protocol = enum { plain, tls }; + + pub fn read(conn: *Connection, buffer: []u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.read(buffer), + .tls => return conn.tls_client.read(conn.stream, buffer), + } + } + + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize { + switch (conn.protocol) { + .plain => return conn.stream.readAtLeast(buffer, len), + .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), + } + } + + pub fn writeAll(conn: *Connection, buffer: []const u8) !void { + switch (conn.protocol) { + .plain => return conn.stream.writeAll(buffer), + .tls => return conn.tls_client.writeAll(conn.stream, buffer), + } + } + + pub fn write(conn: *Connection, buffer: []const u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.write(buffer), + .tls => return conn.tls_client.write(conn.stream, buffer), + } + } +}; + /// TODO: emit error.UnexpectedEndOfStream or something like that when the read /// data does not match the content length. This is necessary since HTTPS disables /// close_notify protection on underlying TLS streams. pub const Request = struct { client: *Client, - stream: net.Stream, - tls_client: std.crypto.tls.Client, - protocol: Protocol, - response_headers: http.Headers, + connection: Connection, redirects_left: u32, + response: Response, + /// These are stored in Request so that they are available when following + /// redirects. + headers: Headers, + + pub const Response = struct { + headers: Response.Headers, + state: State, + header_bytes_owned: bool, + /// This could either be a fixed buffer provided by the API user or it + /// could be our own array list. + header_bytes: std.ArrayListUnmanaged(u8), + max_header_bytes: usize, + + pub const Headers = struct { + location: ?[]const u8 = null, + status: http.Status, + version: http.Version, + content_length: ?u64 = null, + + pub fn parse(bytes: []const u8) !Response.Headers { + var it = mem.split(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); + + const first_line = it.first(); + if (first_line.len < 12) + return error.ShortHttpStatusLine; + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.BadHttpVersion, + }; + if (first_line[8] != ' ') return error.InvalidHttpHeaders; + const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); + + var headers: Response.Headers = .{ + .version = version, + .status = status, + }; + + while (it.next()) |line| { + var line_it = mem.split(u8, line, ": "); + const header_name = line_it.first(); + const header_value = line_it.rest(); + if (std.ascii.eqlIgnoreCase(header_name, "location")) { + headers.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + headers.content_length = try std.fmt.parseInt(u64, header_value, 10); + } + } + + return headers; + } + + test "parse headers" { + const example = + "HTTP/1.1 301 Moved Permanently\r\n" ++ + "Location: https://www.example.com/\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n\r\n"; + const parsed = try Response.Headers.parse(example); + try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); + try testing.expectEqual(http.Status.moved_permanently, parsed.status); + try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse + return error.TestFailed); + try testing.expectEqual(@as(?u64, 220), parsed.content_length); + } + }; + + pub const State = enum { + invalid, + finished, + start, + seen_r, + seen_rn, + seen_rnr, + }; + + pub fn initDynamic(max: usize) Response { + return .{ + .state = .start, + .headers = undefined, + .header_bytes = .{}, + .max_header_bytes = max, + .header_bytes_owned = true, + }; + } + + pub fn initStatic(buf: []u8) Response { + return .{ + .state = .start, + .headers = undefined, + .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, + .max_header_bytes = buf.len, + .header_bytes_owned = false, + }; + } + + /// Returns how many bytes are part of HTTP headers. Always less than or + /// equal to bytes.len. If the amount returned is less than bytes.len, it + /// means the headers ended and the first byte after the double \r\n\r\n is + /// located at `bytes[result]`. 
+ pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize { + var index: usize = 0; + + // TODO: https://github.com/ziglang/zig/issues/8220 + state: while (true) { + switch (r.state) { + .invalid => unreachable, + .finished => unreachable, + .start => while (true) { + switch (bytes.len - index) { + 0 => return index, + 1 => { + if (bytes[index] == '\r') + r.state = .seen_r; + return index + 1; + }, + 2 => { + if (int16(bytes[index..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + } else if (bytes[index + 1] == '\r') { + r.state = .seen_r; + } + return index + 2; + }, + 3 => { + if (int16(bytes[index..][0..2]) == int16("\r\n") and + bytes[index + 2] == '\r') + { + r.state = .seen_rnr; + } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + } else if (bytes[index + 2] == '\r') { + r.state = .seen_r; + } + return index + 3; + }, + 4...15 => { + if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) { + r.state = .finished; + return index + 4; + } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and + bytes[index + 3] == '\r') + { + r.state = .seen_rnr; + index += 4; + continue :state; + } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) { + r.state = .seen_rn; + index += 4; + continue :state; + } else if (bytes[index + 3] == '\r') { + r.state = .seen_r; + index += 4; + continue :state; + } + index += 4; + continue; + }, + else => { + const chunk = bytes[index..][0..16]; + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default)); + switch (sub_index) { + 0...12 => { + index += sub_index + 4; + if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) { + r.state = .finished; + return index; + } + continue; + }, + 13 => { + index += 16; + if (int16(chunk[14..][0..2]) == int16("\n\r")) { + r.state = .seen_rnr; + continue :state; + } + continue; + }, + 14 => { + index += 16; + if (chunk[15] == '\n') { + r.state = .seen_rn; + continue :state; + } + continue; + }, + 15 => { + r.state = .seen_r; + index += 16; + continue :state; + }, + 16 => { + index += 16; + continue; + }, + else => unreachable, + } + }, + } + }, + + .seen_r => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\n' => r.state = .seen_rn, + '\r' => r.state = .seen_r, + else => r.state = .start, + } + return index + 1; + }, + 2 => { + if (int16(bytes[index..][0..2]) == int16("\n\r")) { + r.state = .seen_rnr; + return index + 2; + } + r.state = .start; + return index + 2; + }, + else => { + if (int16(bytes[index..][0..2]) == int16("\n\r") and + bytes[index + 2] == '\n') + { + r.state = .finished; + return index + 3; + } + index += 3; + r.state = .start; + continue :state; + }, + }, + .seen_rn => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => r.state = .seen_rnr, + else => r.state = .start, + } + return index + 1; + }, + else => { + if (int16(bytes[index..][0..2]) == int16("\r\n")) { + r.state = .finished; + return index + 2; + } + index += 2; + r.state = .start; + continue :state; + }, + }, + .seen_rnr => switch (bytes.len - index) { + 0 => return index, + else => { + if (bytes[index] == '\n') { + r.state = .finished; + return index + 1; + } + index += 1; + r.state = .start; + continue :state; + }, + }, + } + + return index; + } + } + + fn parseInt3(nnn: @Vector(3, u8)) u10 { + const zero: @Vector(3, u8) = 
.{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); + } + + test parseInt3 { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000".*)); + try expectEqual(@as(u10, 418), parseInt3("418".*)); + try expectEqual(@as(u10, 999), parseInt3("999".*)); + } + + inline fn int16(array: *const [2]u8) u16 { + return @bitCast(u16, array.*); + } + + inline fn int32(array: *const [4]u8) u32 { + return @bitCast(u32, array.*); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(u64, array.*); + } + + test "find headers end basic" { + var buffer: [1]u8 = undefined; + var r = Response.initStatic(&buffer); + try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); + try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); + try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); + } + + test "find headers end vectorized" { + var buffer: [1]u8 = undefined; + var r = Response.initStatic(&buffer); + const example = + "HTTP/1.1 301 Moved Permanently\r\n" ++ + "Location: https://www.example.com/\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n" ++ + "\r\ncontent"; + try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); + } + }; pub const Headers = struct { method: http.Method = .GET, - connection: Connection, - - pub const Connection = enum { - close, - @"keep-alive", - }; }; - pub const Protocol = enum { http, https }; - pub const Options = struct { max_redirects: u32 = 3, + header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, + + pub const HeaderStrategy = union(enum) { + /// In this case, the client's Allocator will be used to store the + /// entire HTTP header. This value is the maximum total size of + /// HTTP headers allowed, otherwise + /// error.HttpHeadersExceededSizeLimit is returned from read(). + dynamic: usize, + /// This is used to store the entire HTTP header. If the HTTP + /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` + /// is returned from read(). When this is used, `error.OutOfMemory` + /// cannot be returned from `read()`. + static: []u8, + }; }; + /// May be skipped if header strategy is buffer. + pub fn deinit(req: *Request) void { + if (req.response.header_bytes_owned) { + req.response.header_bytes.deinit(req.client.allocator); + } + req.* = undefined; + } + pub fn readAll(req: *Request, buffer: []u8) !usize { return readAtLeast(req, buffer, buffer.len); } @@ -52,7 +426,7 @@ pub const Request = struct { assert(len <= buffer.len); var index: usize = 0; while (index < len) { - const headers_finished = req.response_headers.state == .finished; + const headers_finished = req.response.state == .finished; const amt = try readAdvanced(req, buffer[index..]); if (amt == 0 and headers_finished) break; index += amt; @@ -63,67 +437,102 @@ pub const Request = struct { /// This one can return 0 without meaning EOF. 
/// TODO change to readvAdvanced pub fn readAdvanced(req: *Request, buffer: []u8) !usize { - if (req.response_headers.state == .finished) return readRaw(req, buffer); + if (req.response.state == .finished) return req.connection.read(buffer); - const amt = try readRaw(req, buffer); + const amt = try req.connection.read(buffer); const data = buffer[0..amt]; - const i = req.response_headers.feed(data); - if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders; - if (i < data.len) { - const rest = data[i..]; - std.mem.copy(u8, buffer, rest); - return rest.len; + const i = req.response.findHeadersEnd(data); + if (req.response.state == .invalid) return error.InvalidHttpHeaders; + + const headers_data = data[0..i]; + if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { + return error.HttpHeadersExceededSizeLimit; + } + try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); + + if (req.response.state == .finished) { + req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); + } + + if (req.response.headers.status.class() == .redirect) { + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = try std.Url.parse(location); + const new_req = try req.client.request(new_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.header_bytes_owned) .{ + .dynamic = req.response.max_header_bytes, + } else .{ + .static = req.response.header_bytes.unusedCapacitySlice(), + }, + }); + req.deinit(); + req.* = new_req; + return readAdvanced(req, buffer); + } + + const body_data = data[i..]; + if (body_data.len > 0) { + mem.copy(u8, buffer, body_data); + return body_data.len; } return 0; } - /// Only abstracts over http/https. - fn readRaw(req: *Request, buffer: []u8) !usize { - switch (req.protocol) { - .http => return req.stream.read(buffer), - .https => return req.tls_client.read(req.stream, buffer), - } - } - - /// Only abstracts over http/https. - fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize { - switch (req.protocol) { - .http => return req.stream.readAtLeast(buffer, len), - .https => return req.tls_client.readAtLeast(req.stream, buffer, len), - } + test { + _ = Response; } }; -pub fn deinit(client: *Client, gpa: std.mem.Allocator) void { +pub fn deinit(client: *Client, gpa: Allocator) void { client.ca_bundle.deinit(gpa); client.* = undefined; } +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection { + var conn: Connection = .{ + .stream = try net.tcpConnectToHost(client.allocator, host, port), + .tls_client = undefined, + .protocol = protocol, + }; + + switch (protocol) { + .plain => {}, + .tls => { + conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host); + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. 
+ conn.tls_client.allow_truncation_attacks = true; + }, + } + + return conn; +} + pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request { - const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse + const protocol: Connection.Protocol = if (mem.eql(u8, url.scheme, "http")) + .plain + else if (mem.eql(u8, url.scheme, "https")) + .tls + else return error.UnsupportedUrlScheme; + const port: u16 = url.port orelse switch (protocol) { - .http => 80, - .https => 443, + .plain => 80, + .tls => 443, }; var req: Request = .{ .client = client, - .stream = try net.tcpConnectToHost(client.allocator, url.host, port), - .protocol = protocol, - .tls_client = undefined, + .headers = headers, + .connection = try client.connect(url.host, port, protocol), .redirects_left = options.max_redirects, - }; - - switch (protocol) { - .http => {}, - .https => { - req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host); - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - req.tls_client.allow_truncation_attacks = true; + .response = switch (options.header_strategy) { + .dynamic => |max| Request.Response.initDynamic(max), + .static => |buf| Request.Response.initStatic(buf), }, - } + }; { var h = try std.BoundedArray(u8, 1000).init(0); @@ -132,23 +541,15 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req try h.appendSlice(url.path); try h.appendSlice(" HTTP/1.1\r\nHost: "); try h.appendSlice(url.host); - switch (protocol) { - .https => try h.appendSlice("\r\nUpgrade-Insecure-Requests: 1\r\n"), - .http => try h.appendSlice("\r\n"), - } - try h.writer().print("Connection: {s}\r\n", .{@tagName(headers.connection)}); - try h.appendSlice("\r\n"); + try h.appendSlice("\r\nConnection: close\r\n\r\n"); const header_bytes = h.slice(); - switch (req.protocol) { - .http => { - try req.stream.writeAll(header_bytes); - }, - .https => { - try req.tls_client.writeAll(req.stream, header_bytes); - }, - } + try req.connection.writeAll(header_bytes); } return req; } + +test { + _ = Request; +} diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig deleted file mode 100644 index 63666e1a53..0000000000 --- a/lib/std/http/Headers.zig +++ /dev/null @@ -1,193 +0,0 @@ -status: http.Status, -version: http.Version, - -pub const Parser = struct { - state: State, - headers: Headers, - buffer: [16]u8, - buffer_index: u4, - - pub const init: Parser = .{ - .state = .start, - .headers = .{ - .status = undefined, - .version = undefined, - }, - .buffer = undefined, - .buffer_index = 0, - }; - - pub const State = enum { - invalid, - finished, - start, - expect_status, - find_start_line_end, - line, - line_r, - }; - - /// Returns how many bytes are processed into headers. Always less than or - /// equal to bytes.len. If the amount returned is less than bytes.len, it - /// means the headers ended and the first byte after the double \r\n\r\n is - /// located at `bytes[result]`. 
- pub fn feed(p: *Parser, bytes: []const u8) usize { - var index: usize = 0; - - while (bytes.len - index >= 16) { - index += p.feed16(bytes[index..][0..16]); - switch (p.state) { - .invalid, .finished => return index, - else => continue, - } - } - - while (index < bytes.len) { - var buffer = [1]u8{0} ** 16; - const src = bytes[index..bytes.len]; - std.mem.copy(u8, &buffer, src); - index += p.feed16(&buffer); - switch (p.state) { - .invalid, .finished => return index, - else => continue, - } - } - - return index; - } - - pub fn feed16(p: *Parser, chunk: *const [16]u8) u8 { - switch (p.state) { - .invalid, .finished => return 0, - .start => { - p.headers.version = switch (std.mem.readIntNative(u64, chunk[0..8])) { - std.mem.readIntNative(u64, "HTTP/1.0") => .@"HTTP/1.0", - std.mem.readIntNative(u64, "HTTP/1.1") => .@"HTTP/1.1", - else => return invalid(p, 0), - }; - p.state = .expect_status; - return 8; - }, - .expect_status => { - // example: " 200 OK\r\n" - // example; " 301 Moved Permanently\r\n" - switch (std.mem.readIntNative(u64, chunk[0..8])) { - std.mem.readIntNative(u64, " 200 OK\r") => { - if (chunk[8] != '\n') return invalid(p, 8); - p.headers.status = .ok; - p.state = .line; - return 9; - }, - std.mem.readIntNative(u64, " 301 Mov") => { - p.headers.status = .moved_permanently; - if (!std.mem.eql(u8, chunk[9..], "ed Perma")) - return invalid(p, 9); - p.state = .find_start_line_end; - return 16; - }, - else => { - if (chunk[0] != ' ') return invalid(p, 0); - const status = std.fmt.parseInt(u10, chunk[1..][0..3], 10) catch - return invalid(p, 1); - p.headers.status = @intToEnum(http.Status, status); - const v: @Vector(12, u8) = chunk[4..16].*; - const matches_r = v == @splat(12, @as(u8, '\r')); - const iota = std.simd.iota(u8, 12); - const default = @splat(12, @as(u8, 12)); - const index = 4 + @reduce(.Min, @select(u8, matches_r, iota, default)); - if (index >= 15) { - p.state = .find_start_line_end; - return index; - } - if (chunk[index + 1] != '\n') - return invalid(p, index + 1); - p.state = .line; - return index + 2; - }, - } - }, - .find_start_line_end => { - const v: @Vector(16, u8) = chunk.*; - const matches_r = v == @splat(16, @as(u8, '\r')); - const iota = std.simd.iota(u8, 16); - const default = @splat(16, @as(u8, 16)); - const index = @reduce(.Min, @select(u8, matches_r, iota, default)); - if (index >= 15) { - p.state = .find_start_line_end; - return index; - } - if (chunk[index + 1] != '\n') - return invalid(p, index + 1); - p.state = .line; - return index + 2; - }, - .line => { - const v: @Vector(16, u8) = chunk.*; - const matches_r = v == @splat(16, @as(u8, '\r')); - const iota = std.simd.iota(u8, 16); - const default = @splat(16, @as(u8, 16)); - const index = @reduce(.Min, @select(u8, matches_r, iota, default)); - if (index >= 15) { - return index; - } - if (chunk[index + 1] != '\n') - return invalid(p, index + 1); - if (index + 4 <= 16 and chunk[index + 2] == '\r') { - if (chunk[index + 3] != '\n') return invalid(p, index + 3); - p.state = .finished; - return index + 4; - } - p.state = .line_r; - return index + 2; - }, - .line_r => { - if (chunk[0] == '\r') { - if (chunk[1] != '\n') return invalid(p, 1); - p.state = .finished; - return 2; - } - p.state = .line; - // Here would be nice to use this proposal when it is implemented: - // https://github.com/ziglang/zig/issues/8220 - return 0; - }, - } - } - - fn invalid(p: *Parser, i: u8) u8 { - p.state = .invalid; - return i; - } -}; - -const std = @import("../std.zig"); -const http = std.http; -const Headers = @This(); 
-const testing = std.testing; - -test "status line ok" { - var p = Parser.init; - const line = "HTTP/1.1 200 OK\r\n"; - try testing.expect(p.feed(line) == line.len); - try testing.expectEqual(Parser.State.line, p.state); - try testing.expect(p.headers.version == .@"HTTP/1.1"); - try testing.expect(p.headers.status == .ok); -} - -test "status line non hot path long msg" { - var p = Parser.init; - const line = "HTTP/1.0 418 I'm a teapot\r\n"; - try testing.expect(p.feed(line) == line.len); - try testing.expectEqual(Parser.State.line, p.state); - try testing.expect(p.headers.version == .@"HTTP/1.0"); - try testing.expect(p.headers.status == .teapot); -} - -test "status line non hot path short msg" { - var p = Parser.init; - const line = "HTTP/1.1 418 lol\r\n"; - try testing.expect(p.feed(line) == line.len); - try testing.expectEqual(Parser.State.line, p.state); - try testing.expect(p.headers.version == .@"HTTP/1.1"); - try testing.expect(p.headers.status == .teapot); -} From ba1e53f116b27f97828d495af2ea5b87fd9632cb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 4 Jan 2023 21:07:24 -0700 Subject: [PATCH 5/7] avoid triggering LLVM bug on MIPS See #13782 --- lib/std/http/Client.zig | 7 +++++++ test/behavior/bitcast.zig | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 72e9321a7c..9a247208e3 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -551,5 +551,12 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req } test { + const builtin = @import("builtin"); + const native_endian = comptime builtin.cpu.arch.endian(); + if (builtin.zig_backend == .stage2_llvm and native_endian == .Big) { + // https://github.com/ziglang/zig/issues/13782 + return error.SkipZigTest; + } + _ = Request; } diff --git a/test/behavior/bitcast.zig b/test/behavior/bitcast.zig index 0e8ff65414..0c9e96c67e 100644 --- a/test/behavior/bitcast.zig +++ b/test/behavior/bitcast.zig @@ -109,7 +109,7 @@ fn testBitCastuXToBytes(comptime N: usize) !void { const bytes = std.mem.asBytes(&x); const byte_count = (N + 7) / 8; - switch (builtin.cpu.arch.endian()) { + switch (native_endian) { .Little => { var byte_i = 0; while (byte_i < (byte_count - 1)) : (byte_i += 1) { @@ -333,7 +333,7 @@ test "comptime @bitCast packed struct to int and back" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (comptime builtin.zig_backend == .stage2_llvm and builtin.cpu.arch.endian() == .Big) { + if (builtin.zig_backend == .stage2_llvm and native_endian == .Big) { // https://github.com/ziglang/zig/issues/13782 return error.SkipZigTest; } From 450f3bc9250eca86b470ec29c228165054016c76 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 5 Jan 2023 13:31:26 -0700 Subject: [PATCH 6/7] std.http.Class: classify out-of-range codes as server_error RFC 9110 section 15: Values outside the range 100..599 are invalid. Implementations often use three-digit integer values outside of that range (i.e., 600..999) for internal communication of non-HTTP status (e.g., library errors). A client that receives a response with an invalid status code SHOULD process the response as if it had a 5xx (Server Error) status code. 
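
For illustration only (an editor's sketch, not part of this patch): with this change, a three-digit code outside 100..599 simply classifies as a server error. The snippet assumes http.Status is non-exhaustive, which the client's status-line parsing already relies on via @intToEnum.

    const std = @import("std");

    test "out-of-range status code classifies as server_error" {
        const status = @intToEnum(std.http.Status, 600);
        try std.testing.expectEqual(std.http.Status.Class.server_error, status.class());
    }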
--- lib/std/http.zig | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index baeb7679f1..f9ac019495 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -218,7 +218,6 @@ pub const Status = enum(u10) { } pub const Class = enum { - nonstandard, informational, success, redirect, @@ -232,8 +231,7 @@ pub const Status = enum(u10) { 200...299 => .success, 300...399 => .redirect, 400...499 => .client_error, - 500...599 => .server_error, - else => .nonstandard, + else => .server_error, }; } From 3055ab7f8639deca318f238f21680776a7149acb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 5 Jan 2023 13:39:17 -0700 Subject: [PATCH 7/7] std.http.Client: fail header parsing under more conditions * when HTTP header continuations are used * when content-type or location header occurs more than once --- lib/std/http/Client.zig | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 9a247208e3..c6262f4706 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -96,7 +96,7 @@ pub const Request = struct { int64("HTTP/1.1") => .@"HTTP/1.1", else => return error.BadHttpVersion, }; - if (first_line[8] != ' ') return error.InvalidHttpHeaders; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); var headers: Response.Headers = .{ @@ -105,12 +105,19 @@ pub const Request = struct { }; while (it.next()) |line| { + if (line.len == 0) return error.HttpHeadersInvalid; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } var line_it = mem.split(u8, line, ": "); const header_name = line_it.first(); const header_value = line_it.rest(); if (std.ascii.eqlIgnoreCase(header_name, "location")) { + if (headers.location != null) return error.HttpHeadersInvalid; headers.location = header_value; } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (headers.content_length != null) return error.HttpHeadersInvalid; headers.content_length = try std.fmt.parseInt(u64, header_value, 10); } } @@ -131,6 +138,29 @@ pub const Request = struct { return error.TestFailed); try testing.expectEqual(@as(?u64, 220), parsed.content_length); } + + test "header continuation" { + const example = + "HTTP/1.0 200 OK\r\n" ++ + "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ + "Content-Length: 220\r\n\r\n"; + try testing.expectError( + error.HttpHeaderContinuationsUnsupported, + Response.Headers.parse(example), + ); + } + + test "extra content length" { + const example = + "HTTP/1.0 200 OK\r\n" ++ + "Content-Length: 220\r\n" ++ + "Content-Type: text/html; charset=UTF-8\r\n" ++ + "content-length: 220\r\n\r\n"; + try testing.expectError( + error.HttpHeadersInvalid, + Response.Headers.parse(example), + ); + } }; pub const State = enum { @@ -442,7 +472,7 @@ pub const Request = struct { const amt = try req.connection.read(buffer); const data = buffer[0..amt]; const i = req.response.findHeadersEnd(data); - if (req.response.state == .invalid) return error.InvalidHttpHeaders; + if (req.response.state == .invalid) return error.HttpHeadersInvalid; const headers_data = data[0..i]; if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
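
Editor's addendum (not part of any patch above): a minimal sketch of how the client API reads after this series, using a plain-HTTP URL so no CA bundle setup is needed. The URL, buffer size, and choice of GeneralPurposeAllocator are illustrative assumptions; the defaults give a GET request, up to 3 redirects, and a 16 KB dynamically allocated header buffer.

    const std = @import("std");

    pub fn main() !void {
        var gpa_state = std.heap.GeneralPurposeAllocator(.{}){};
        defer _ = gpa_state.deinit();
        const gpa = gpa_state.allocator();

        // The client owns only an allocator and an (optional) CA bundle.
        var client: std.http.Client = .{ .allocator = gpa };
        defer client.deinit(gpa);

        const url = try std.Url.parse("http://example.com/");

        // Default Headers (.GET) and Options (max_redirects = 3, dynamic header storage).
        var req = try client.request(url, .{}, .{});
        defer req.deinit();

        // Read the response body; headers are parsed and stripped internally.
        var buf: [4096]u8 = undefined;
        const n = try req.readAll(&buf);
        std.debug.print("{s}\n", .{buf[0..n]});
    }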