diff --git a/lib/std/http.zig b/lib/std/http.zig index 944271df27..d6d21d2c9e 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,4 +1,10 @@ pub const Client = @import("http/Client.zig"); +pub const Headers = @import("http/Headers.zig"); + +pub const Version = enum { + @"HTTP/1.0", + @"HTTP/1.1", +}; /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton @@ -242,55 +248,6 @@ pub const Status = enum(u10) { } }; -pub const Headers = struct { - state: State = .start, - invalid_index: u32 = undefined, - - pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished }; - - /// Returns how many bytes are processed into headers. Always less than or - /// equal to bytes.len. If the amount returned is less than bytes.len, it - /// means the headers ended and the first byte after the double \r\n\r\n is - /// located at `bytes[result]`. - pub fn feed(h: *Headers, bytes: []const u8) usize { - for (bytes) |b, i| { - switch (h.state) { - .start => switch (b) { - '\r' => h.state = .nl_r, - '\n' => return invalid(h, i), - else => {}, - }, - .nl_r => switch (b) { - '\n' => h.state = .nl_n, - else => return invalid(h, i), - }, - .nl_n => switch (b) { - '\r' => h.state = .nl2_r, - else => h.state = .line, - }, - .nl2_r => switch (b) { - '\n' => h.state = .finished, - else => return invalid(h, i), - }, - .line => switch (b) { - '\r' => h.state = .nl_r, - '\n' => return invalid(h, i), - else => {}, - }, - .invalid => return i, - .finished => return i, - } - } - return bytes.len; - } - - fn invalid(h: *Headers, i: usize) usize { - h.invalid_index = @intCast(u32, i); - h.state = .invalid; - return i; - } -}; - const std = @import("std.zig"); test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index b6be5cee10..0b5f06669a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,7 +21,8 @@ pub const Request = struct { stream: net.Stream, tls_client: std.crypto.tls.Client, protocol: Protocol, - response_headers: http.Headers = .{}, + response_headers: http.Headers, + redirects_left: u32, pub const Headers = struct { method: http.Method = .GET, @@ -35,7 +36,9 @@ pub const Request = struct { pub const Protocol = enum { http, https }; - pub const Options = struct {}; + pub const Options = struct { + max_redirects: u32 = 3, + }; pub fn readAll(req: *Request, buffer: []u8) !usize { return readAtLeast(req, buffer, buffer.len); @@ -97,8 +100,6 @@ pub fn deinit(client: *Client, gpa: std.mem.Allocator) void { } pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request { - _ = options; // we have no options yet - const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = url.port orelse switch (protocol) { @@ -111,6 +112,7 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req .stream = try net.tcpConnectToHost(client.allocator, url.host, port), .protocol = protocol, .tls_client = undefined, + .redirects_left = options.max_redirects, }; switch (protocol) { diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig new file mode 100644 index 0000000000..63666e1a53 --- /dev/null +++ b/lib/std/http/Headers.zig @@ -0,0 +1,193 @@ +status: http.Status, +version: http.Version, + +pub const Parser = struct { + state: State, + headers: Headers, + buffer: [16]u8, + buffer_index: u4, + + pub const init: Parser = .{ + .state = .start, + .headers = .{ + .status = undefined, + .version = undefined, + }, + .buffer = undefined, + .buffer_index = 0, + }; + + pub const State = enum { + invalid, + finished, + start, + expect_status, + find_start_line_end, + line, + line_r, + }; + + /// Returns how many bytes are processed into headers. Always less than or + /// equal to bytes.len. If the amount returned is less than bytes.len, it + /// means the headers ended and the first byte after the double \r\n\r\n is + /// located at `bytes[result]`. + pub fn feed(p: *Parser, bytes: []const u8) usize { + var index: usize = 0; + + while (bytes.len - index >= 16) { + index += p.feed16(bytes[index..][0..16]); + switch (p.state) { + .invalid, .finished => return index, + else => continue, + } + } + + while (index < bytes.len) { + var buffer = [1]u8{0} ** 16; + const src = bytes[index..bytes.len]; + std.mem.copy(u8, &buffer, src); + index += p.feed16(&buffer); + switch (p.state) { + .invalid, .finished => return index, + else => continue, + } + } + + return index; + } + + pub fn feed16(p: *Parser, chunk: *const [16]u8) u8 { + switch (p.state) { + .invalid, .finished => return 0, + .start => { + p.headers.version = switch (std.mem.readIntNative(u64, chunk[0..8])) { + std.mem.readIntNative(u64, "HTTP/1.0") => .@"HTTP/1.0", + std.mem.readIntNative(u64, "HTTP/1.1") => .@"HTTP/1.1", + else => return invalid(p, 0), + }; + p.state = .expect_status; + return 8; + }, + .expect_status => { + // example: " 200 OK\r\n" + // example; " 301 Moved Permanently\r\n" + switch (std.mem.readIntNative(u64, chunk[0..8])) { + std.mem.readIntNative(u64, " 200 OK\r") => { + if (chunk[8] != '\n') return invalid(p, 8); + p.headers.status = .ok; + p.state = .line; + return 9; + }, + std.mem.readIntNative(u64, " 301 Mov") => { + p.headers.status = .moved_permanently; + if (!std.mem.eql(u8, chunk[9..], "ed Perma")) + return invalid(p, 9); + p.state = .find_start_line_end; + return 16; + }, + else => { + if (chunk[0] != ' ') return invalid(p, 0); + const status = std.fmt.parseInt(u10, chunk[1..][0..3], 10) catch + return invalid(p, 1); + p.headers.status = @intToEnum(http.Status, status); + const v: @Vector(12, u8) = chunk[4..16].*; + const matches_r = v == @splat(12, @as(u8, '\r')); + const iota = std.simd.iota(u8, 12); + const default = @splat(12, @as(u8, 12)); + const index = 4 + @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + p.state = .find_start_line_end; + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + p.state = .line; + return index + 2; + }, + } + }, + .find_start_line_end => { + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const index = @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + p.state = .find_start_line_end; + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + p.state = .line; + return index + 2; + }, + .line => { + const v: @Vector(16, u8) = chunk.*; + const matches_r = v == @splat(16, @as(u8, '\r')); + const iota = std.simd.iota(u8, 16); + const default = @splat(16, @as(u8, 16)); + const index = @reduce(.Min, @select(u8, matches_r, iota, default)); + if (index >= 15) { + return index; + } + if (chunk[index + 1] != '\n') + return invalid(p, index + 1); + if (index + 4 <= 16 and chunk[index + 2] == '\r') { + if (chunk[index + 3] != '\n') return invalid(p, index + 3); + p.state = .finished; + return index + 4; + } + p.state = .line_r; + return index + 2; + }, + .line_r => { + if (chunk[0] == '\r') { + if (chunk[1] != '\n') return invalid(p, 1); + p.state = .finished; + return 2; + } + p.state = .line; + // Here would be nice to use this proposal when it is implemented: + // https://github.com/ziglang/zig/issues/8220 + return 0; + }, + } + } + + fn invalid(p: *Parser, i: u8) u8 { + p.state = .invalid; + return i; + } +}; + +const std = @import("../std.zig"); +const http = std.http; +const Headers = @This(); +const testing = std.testing; + +test "status line ok" { + var p = Parser.init; + const line = "HTTP/1.1 200 OK\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.1"); + try testing.expect(p.headers.status == .ok); +} + +test "status line non hot path long msg" { + var p = Parser.init; + const line = "HTTP/1.0 418 I'm a teapot\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.0"); + try testing.expect(p.headers.status == .teapot); +} + +test "status line non hot path short msg" { + var p = Parser.init; + const line = "HTTP/1.1 418 lol\r\n"; + try testing.expect(p.feed(line) == line.len); + try testing.expectEqual(Parser.State.line, p.state); + try testing.expect(p.headers.version == .@"HTTP/1.1"); + try testing.expect(p.headers.status == .teapot); +}