std.http.Headers.Parser: parse version and status

This commit is contained in:
Andrew Kelley 2023-01-03 19:14:44 -07:00
parent 2cdc0a8b50
commit 5d9429579d
3 changed files with 205 additions and 53 deletions

View File

@ -1,4 +1,10 @@
pub const Client = @import("http/Client.zig");
pub const Headers = @import("http/Headers.zig");
/// HTTP protocol version. Each tag is spelled exactly like the 8-byte
/// version token it stands for.
pub const Version = enum {
/// The token `HTTP/1.0`.
@"HTTP/1.0",
/// The token `HTTP/1.1`.
@"HTTP/1.1",
};
/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition
@ -242,55 +248,6 @@ pub const Status = enum(u10) {
}
};
/// Byte-at-a-time scanner that finds the end of an HTTP header block, i.e.
/// the `\r\n\r\n` sequence that terminates the last header line.
///
/// The scanner keeps its state between calls, so a header block may be fed
/// in arbitrarily sized pieces.
pub const Headers = struct {
    /// Current state of the line-ending state machine.
    state: State = .start,
    /// Only meaningful when `state == .invalid`: index of the offending byte
    /// within the slice passed to the `feed` call that detected it.
    invalid_index: u32 = undefined,

    pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished };

    /// Returns how many bytes are processed into headers. Always less than or
    /// equal to bytes.len.
    ///
    /// If the amount returned is less than bytes.len, inspect `state`:
    /// * `.finished`: the headers ended and the first byte after the double
    ///   \r\n\r\n is located at `bytes[result]`.
    /// * `.invalid`: `bytes[result]` is the byte that violated the expected
    ///   CRLF structure (also recorded in `invalid_index`).
    ///
    /// When the terminating `\n` is the last byte of `bytes`, the result
    /// equals bytes.len and `state` is `.finished`.
    pub fn feed(h: *Headers, bytes: []const u8) usize {
        for (bytes) |b, i| {
            switch (h.state) {
                // Inside a line: only "\r" may begin a line ending.
                .start, .line => switch (b) {
                    '\r' => h.state = .nl_r,
                    '\n' => return invalid(h, i),
                    else => {},
                },
                // Saw "\r": it must be completed by "\n".
                .nl_r => switch (b) {
                    '\n' => h.state = .nl_n,
                    else => return invalid(h, i),
                },
                // At the start of a new line: "\r" begins the empty line that
                // ends the headers. A bare "\n" is rejected here exactly as
                // it is in every other position.
                .nl_n => switch (b) {
                    '\r' => h.state = .nl2_r,
                    '\n' => return invalid(h, i),
                    else => h.state = .line,
                },
                // Saw "\r\n\r": the final "\n" completes the header block.
                .nl2_r => switch (b) {
                    '\n' => h.state = .finished,
                    else => return invalid(h, i),
                },
                // Terminal states consume nothing further.
                .invalid, .finished => return i,
            }
        }
        return bytes.len;
    }

    /// Records the offending index, switches to `.invalid`, and returns `i`
    /// so callers can write `return invalid(h, i)`.
    fn invalid(h: *Headers, i: usize) usize {
        h.invalid_index = @intCast(u32, i);
        h.state = .invalid;
        return i;
    }
};
const std = @import("std.zig");
test {

View File

@ -21,7 +21,8 @@ pub const Request = struct {
stream: net.Stream,
tls_client: std.crypto.tls.Client,
protocol: Protocol,
response_headers: http.Headers = .{},
response_headers: http.Headers,
redirects_left: u32,
pub const Headers = struct {
method: http.Method = .GET,
@ -35,7 +36,9 @@ pub const Request = struct {
/// URL schemes accepted by `request`, which looks the scheme string up in
/// this enum and fails with `error.UnsupportedUrlScheme` for anything else.
pub const Protocol = enum { http, https };
pub const Options = struct {};
/// Per-request settings passed to `request`.
pub const Options = struct {
/// Initial value of the request's `redirects_left` counter.
max_redirects: u32 = 3,
};
pub fn readAll(req: *Request, buffer: []u8) !usize {
return readAtLeast(req, buffer, buffer.len);
@ -97,8 +100,6 @@ pub fn deinit(client: *Client, gpa: std.mem.Allocator) void {
}
pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request {
_ = options; // we have no options yet
const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse
return error.UnsupportedUrlScheme;
const port: u16 = url.port orelse switch (protocol) {
@ -111,6 +112,7 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req
.stream = try net.tcpConnectToHost(client.allocator, url.host, port),
.protocol = protocol,
.tls_client = undefined,
.redirects_left = options.max_redirects,
};
switch (protocol) {

193
lib/std/http/Headers.zig Normal file
View File

@ -0,0 +1,193 @@
/// Status code taken from the response status line; filled in by `Parser`.
status: http.Status,
/// Protocol version taken from the start of the status line; filled in by
/// `Parser`.
version: http.Version,
/// Incremental parser for the head of an HTTP response. It reads the
/// protocol version and status code from the status line into `headers`, then
/// scans header lines until the empty line that terminates them.
///
/// Input is processed in fixed 16-byte chunks (see `feed16`); `feed` adapts
/// arbitrary slices to that chunk size.
///
/// NOTE(review): after the status line the parser enters `.line`, where a
/// leading "\r\n" is treated as the end of a line rather than the end of the
/// headers, so a response without any header fields never reaches
/// `.finished`. Left unchanged because the tests in this file pin `.line` as
/// the state that follows the status line.
pub const Parser = struct {
    /// Current parser state; persists across `feed` / `feed16` calls.
    state: State,
    /// Values extracted so far. `status` and `version` are undefined until
    /// the corresponding part of the status line has been parsed.
    headers: Headers,
    // NOTE(review): `buffer` and `buffer_index` are initialized by `init` but
    // not read anywhere in this struct.
    buffer: [16]u8,
    buffer_index: u4,

    /// A parser positioned at the very beginning of a response.
    pub const init: Parser = .{
        .state = .start,
        .headers = .{
            .status = undefined,
            .version = undefined,
        },
        .buffer = undefined,
        .buffer_index = 0,
    };

    pub const State = enum {
        /// Malformed input was seen; nothing further is consumed.
        invalid,
        /// The empty line that ends the headers was consumed.
        finished,
        /// Expecting the 8-byte version token at the start of the chunk.
        start,
        /// Expecting " <3-digit status>" directly after the version token.
        expect_status,
        /// Skipping the remainder of the status line up to its "\r\n".
        find_start_line_end,
        /// Scanning header bytes for the next "\r\n".
        line,
        /// A line-ending "\r\n" was just consumed; an immediately following
        /// "\r\n" ends the headers.
        line_r,
    };

    /// Returns how many bytes are processed into headers. Always less than or
    /// equal to bytes.len. If the amount returned is less than bytes.len,
    /// check `state`: `.finished` means the headers ended and the first byte
    /// after the double \r\n\r\n is located at `bytes[result]`; `.invalid`
    /// means the input was rejected.
    ///
    /// NOTE(review): a trailing piece shorter than 16 bytes is zero-padded
    /// before it is handed to `feed16`, so input that stops in the middle of
    /// a token (for example right after a '\r') is reported as `.invalid`
    /// instead of waiting for more data.
    pub fn feed(p: *Parser, bytes: []const u8) usize {
        var index: usize = 0;

        // Full chunks are parsed in place.
        while (bytes.len - index >= 16) {
            index += p.feed16(bytes[index..][0..16]);
            switch (p.state) {
                .invalid, .finished => return index,
                else => continue,
            }
        }

        // The remainder is copied into a zero-padded scratch chunk.
        while (index < bytes.len) {
            var buffer = [1]u8{0} ** 16;
            const src = bytes[index..bytes.len];
            std.mem.copy(u8, &buffer, src);
            const consumed: usize = p.feed16(&buffer);
            // `feed16` counts padding bytes when it scans to the end of the
            // chunk; only bytes that really came from the caller may be
            // reported, otherwise the result could exceed bytes.len.
            index += if (consumed > src.len) src.len else consumed;
            switch (p.state) {
                .invalid, .finished => return index,
                else => continue,
            }
        }
        return index;
    }

    /// Parses one 16-byte chunk according to the current state and returns
    /// how many bytes of it were consumed (0 to 16). The caller is expected
    /// to present the unconsumed bytes again at the start of the next chunk.
    /// On malformed input the state becomes `.invalid` and the offset of the
    /// offending byte within the chunk is returned.
    pub fn feed16(p: *Parser, chunk: *const [16]u8) u8 {
        switch (p.state) {
            .invalid, .finished => return 0,
            .start => {
                // Compare the 8-byte version token as a single integer.
                p.headers.version = switch (std.mem.readIntNative(u64, chunk[0..8])) {
                    std.mem.readIntNative(u64, "HTTP/1.0") => .@"HTTP/1.0",
                    std.mem.readIntNative(u64, "HTTP/1.1") => .@"HTTP/1.1",
                    else => return invalid(p, 0),
                };
                p.state = .expect_status;
                return 8;
            },
            .expect_status => {
                // example: " 200 OK\r\n"
                // example: " 301 Moved Permanently\r\n"
                // Fast paths for the two most common status lines.
                switch (std.mem.readIntNative(u64, chunk[0..8])) {
                    std.mem.readIntNative(u64, " 200 OK\r") => {
                        if (chunk[8] != '\n') return invalid(p, 8);
                        p.headers.status = .ok;
                        p.state = .line;
                        return 9;
                    },
                    // The second half of the chunk must be compared as a full
                    // 8-byte slice; slices of different length never compare
                    // equal. Any other reason phrase takes the generic path.
                    std.mem.readIntNative(u64, " 301 Mov") => if (std.mem.eql(u8, chunk[8..16], "ed Perma")) {
                        p.headers.status = .moved_permanently;
                        p.state = .find_start_line_end;
                        return 16;
                    },
                    else => {},
                }

                // Generic path: " ddd" followed by an arbitrary reason phrase.
                if (chunk[0] != ' ') return invalid(p, 0);
                const status = std.fmt.parseInt(u10, chunk[1..][0..3], 10) catch
                    return invalid(p, 1);
                p.headers.status = @intToEnum(http.Status, status);

                // Offset of the first '\r' in chunk[4..16], or 16 if absent.
                const v: @Vector(12, u8) = chunk[4..16].*;
                const matches_r = v == @splat(12, @as(u8, '\r'));
                const iota = std.simd.iota(u8, 12);
                const default = @splat(12, @as(u8, 12));
                const index = 4 + @reduce(.Min, @select(u8, matches_r, iota, default));
                if (index >= 15) {
                    // Either no '\r' at all, or it is the last byte and its
                    // '\n' is in the next chunk.
                    p.state = .find_start_line_end;
                    return index;
                }
                if (chunk[index + 1] != '\n')
                    return invalid(p, index + 1);
                p.state = .line;
                return index + 2;
            },
            .find_start_line_end => {
                // Offset of the first '\r' in the chunk, or 16 if absent.
                const v: @Vector(16, u8) = chunk.*;
                const matches_r = v == @splat(16, @as(u8, '\r'));
                const iota = std.simd.iota(u8, 16);
                const default = @splat(16, @as(u8, 16));
                const index = @reduce(.Min, @select(u8, matches_r, iota, default));
                if (index >= 15) {
                    // Stay in this state; a '\r' in the last byte is left
                    // unconsumed so the next chunk starts with it.
                    return index;
                }
                if (chunk[index + 1] != '\n')
                    return invalid(p, index + 1);
                p.state = .line;
                return index + 2;
            },
            .line => {
                // Offset of the first '\r' in the chunk, or 16 if absent.
                const v: @Vector(16, u8) = chunk.*;
                const matches_r = v == @splat(16, @as(u8, '\r'));
                const iota = std.simd.iota(u8, 16);
                const default = @splat(16, @as(u8, 16));
                const index = @reduce(.Min, @select(u8, matches_r, iota, default));
                if (index >= 15) {
                    return index;
                }
                if (chunk[index + 1] != '\n')
                    return invalid(p, index + 1);
                // If the whole "\r\n\r\n" fits in this chunk, finish here.
                if (index + 4 <= 16 and chunk[index + 2] == '\r') {
                    if (chunk[index + 3] != '\n') return invalid(p, index + 3);
                    p.state = .finished;
                    return index + 4;
                }
                p.state = .line_r;
                return index + 2;
            },
            .line_r => {
                if (chunk[0] == '\r') {
                    if (chunk[1] != '\n') return invalid(p, 1);
                    p.state = .finished;
                    return 2;
                }
                p.state = .line;
                // Here would be nice to use this proposal when it is implemented:
                // https://github.com/ziglang/zig/issues/8220
                return 0;
            },
        }
    }

    /// Switches to `.invalid` and returns `i`, so callers can write
    /// `return invalid(p, i)`.
    fn invalid(p: *Parser, i: u8) u8 {
        p.state = .invalid;
        return i;
    }
};
const std = @import("../std.zig");
const http = std.http;
const Headers = @This();
const testing = std.testing;
// Hot path: "200 OK" is recognized with a single 8-byte comparison.
test "status line ok" {
    var p = Parser.init;
    const line = "HTTP/1.1 200 OK\r\n";
    // expectEqual reports both values on failure, unlike expect(a == b).
    try testing.expectEqual(@as(usize, line.len), p.feed(line));
    try testing.expectEqual(Parser.State.line, p.state);
    try testing.expectEqual(http.Version.@"HTTP/1.1", p.headers.version);
    try testing.expectEqual(http.Status.ok, p.headers.status);
}
// Generic path where the reason phrase runs past the 16-byte chunk that
// holds the status code.
test "status line non hot path long msg" {
    var p = Parser.init;
    const line = "HTTP/1.0 418 I'm a teapot\r\n";
    // expectEqual reports both values on failure, unlike expect(a == b).
    try testing.expectEqual(@as(usize, line.len), p.feed(line));
    try testing.expectEqual(Parser.State.line, p.state);
    try testing.expectEqual(http.Version.@"HTTP/1.0", p.headers.version);
    try testing.expectEqual(http.Status.teapot, p.headers.status);
}
// Generic path where the whole reason phrase and its "\r\n" fit in the same
// chunk as the status code.
test "status line non hot path short msg" {
    var p = Parser.init;
    const line = "HTTP/1.1 418 lol\r\n";
    // expectEqual reports both values on failure, unlike expect(a == b).
    try testing.expectEqual(@as(usize, line.len), p.feed(line));
    try testing.expectEqual(Parser.State.line, p.state);
    try testing.expectEqual(http.Version.@"HTTP/1.1", p.headers.version);
    try testing.expectEqual(http.Status.teapot, p.headers.status);
}