std.http: parser fixes

* add API for iterating over custom HTTP headers
* remove `trailing` flag from std.http.Client.parse. Instead, simply
  don't call parse() for trailers.
* fix the logic inside that parse() function. it was using wrong std.mem
  functions, ignoring malformed data, and returned errors on dead
  branches.
* simplify logic inside wait()
* fix HeadersParser not dropping the 2 read bytes of \r\n after a
  chunked transfer
* move the trailers test to be a std lib unit test and make it pass
This commit is contained in:
Andrew Kelley 2024-02-16 18:35:57 -07:00
parent d574875f00
commit 78192637fb
3 changed files with 139 additions and 61 deletions

View File

@ -428,12 +428,14 @@ pub const Response = struct {
CompressionUnsupported,
};
pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void {
var it = mem.tokenizeAny(u8, bytes, "\r\n");
pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
var it = mem.splitSequence(u8, bytes, "\r\n");
const first_line = it.next() orelse return error.HttpHeadersInvalid;
if (first_line.len < 12)
const first_line = it.next().?;
if (first_line.len < 12) {
std.debug.print("first line: '{s}'\n", .{first_line});
return error.HttpHeadersInvalid;
}
const version: http.Version = switch (int64(first_line[0..8])) {
int64("HTTP/1.0") => .@"HTTP/1.0",
@ -449,17 +451,16 @@ pub const Response = struct {
res.reason = reason;
while (it.next()) |line| {
if (line.len == 0) return error.HttpHeadersInvalid;
if (line.len == 0) return;
switch (line[0]) {
' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
else => {},
}
var line_it = mem.tokenizeAny(u8, line, ": ");
const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
var line_it = mem.splitSequence(u8, line, ": ");
const header_name = line_it.next().?;
const header_value = line_it.rest();
if (trailing) continue;
if (header_value.len == 0) return error.HttpHeadersInvalid;
if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
@ -538,6 +539,10 @@ pub const Response = struct {
try expectEqual(@as(u10, 999), parseInt3("999"));
}
pub fn iterateHeaders(r: Response) proto.HeaderIterator {
return proto.HeaderIterator.init(r.parser.get());
}
version: http.Version,
status: http.Status,
reason: []const u8,
@ -868,7 +873,7 @@ pub const Request = struct {
if (req.response.parser.state.isContent()) break;
}
try req.response.parse(req.response.parser.get(), false);
try req.response.parse(req.response.parser.get());
if (req.response.status == .@"continue") {
// We're done parsing the continue response; reset to prepare
@ -903,21 +908,21 @@ pub const Request = struct {
return; // The response is empty; no further setup or redirection is necessary.
}
if (req.response.transfer_encoding != .none) {
switch (req.response.transfer_encoding) {
.none => unreachable,
.chunked => {
req.response.parser.next_chunk_length = 0;
req.response.parser.state = .chunk_head_size;
},
}
} else if (req.response.content_length) |cl| {
req.response.parser.next_chunk_length = cl;
switch (req.response.transfer_encoding) {
.none => {
if (req.response.content_length) |cl| {
req.response.parser.next_chunk_length = cl;
if (cl == 0) req.response.parser.done = true;
} else {
// read until the connection is closed
req.response.parser.next_chunk_length = std.math.maxInt(u64);
if (cl == 0) req.response.parser.done = true;
} else {
// read until the connection is closed
req.response.parser.next_chunk_length = std.math.maxInt(u64);
}
},
.chunked => {
req.response.parser.next_chunk_length = 0;
req.response.parser.state = .chunk_head_size;
},
}
if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
@ -1014,27 +1019,16 @@ pub const Request = struct {
//.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
else => try req.transferRead(buffer),
};
if (out_index > 0) return out_index;
if (out_index == 0) {
const has_trail = !req.response.parser.state.isContent();
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.?.fill();
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.?.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
req.connection.?.drop(@intCast(nchecked));
}
if (has_trail) {
// The response headers before the trailers are already
// guaranteed to be valid, so they will always be parsed again
// and cannot return an error.
// This will *only* fail for a malformed trailer.
req.response.parse(req.response.parser.get(), true) catch return error.InvalidTrailers;
}
const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
req.connection.?.drop(@intCast(nchecked));
}
return out_index;
return 0;
}
/// Reads data from the response body. Must be called after `wait`.

View File

@ -570,9 +570,10 @@ pub const HeadersParser = struct {
.chunk_data => if (r.next_chunk_length == 0) {
if (std.mem.eql(u8, conn.peek(), "\r\n")) {
r.state = .finished;
r.done = true;
conn.drop(2);
} else {
// The trailer section is formatted identically to the header section.
// The trailer section is formatted identically
// to the header section.
r.state = .seen_rn;
}
r.done = true;
@ -613,6 +614,68 @@ pub const HeadersParser = struct {
}
};
pub const HeaderIterator = struct {
bytes: []const u8,
index: usize,
is_trailer: bool,
pub fn init(bytes: []const u8) HeaderIterator {
return .{
.bytes = bytes,
.index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
.is_trailer = false,
};
}
pub fn next(it: *HeaderIterator) ?std.http.Header {
const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
const name = kv_it.next().?;
const value = kv_it.rest();
if (value.len == 0) {
if (it.is_trailer) return null;
const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
return null;
it.is_trailer = true;
it.index = next_end + 2;
kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
return .{
.name = kv_it.next().?,
.value = kv_it.rest(),
};
}
it.index = end + 2;
return .{
.name = name,
.value = value,
};
}
test next {
var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
try std.testing.expect(!it.is_trailer);
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("a", header.name);
try std.testing.expectEqualStrings("b", header.value);
}
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("c", header.name);
try std.testing.expectEqualStrings("d", header.value);
}
{
const header = it.next().?;
try std.testing.expect(it.is_trailer);
try std.testing.expectEqualStrings("e", header.name);
try std.testing.expectEqualStrings("f", header.value);
}
try std.testing.expectEqual(null, it.next());
}
};
inline fn int16(array: *const [2]u8) u16 {
return @as(u16, @bitCast(array.*));
}

View File

@ -1,7 +1,8 @@
const std = @import("std");
const testing = std.testing;
test "trailers" {
const gpa = std.testing.allocator;
const gpa = testing.allocator;
var http_server = std.http.Server.init(.{
.reuse_address = true,
@ -21,28 +22,49 @@ test "trailers" {
defer gpa.free(location);
const uri = try std.Uri.parse(location);
var server_header_buffer: [1024]u8 = undefined;
var req = try client.open(.GET, uri, .{
.server_header_buffer = &server_header_buffer,
});
defer req.deinit();
{
var server_header_buffer: [1024]u8 = undefined;
var req = try client.open(.GET, uri, .{
.server_header_buffer = &server_header_buffer,
});
defer req.deinit();
try req.send(.{});
try req.wait();
try req.send(.{});
try req.wait();
const body = try req.reader().readAllAlloc(gpa, 8192);
defer gpa.free(body);
const body = try req.reader().readAllAlloc(gpa, 8192);
defer gpa.free(body);
try std.testing.expectEqualStrings("Hello, World!\n", body);
if (true) @panic("TODO implement inspecting custom headers in responses");
//try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?);
try testing.expectEqualStrings("Hello, World!\n", body);
var it = req.response.iterateHeaders();
{
const header = it.next().?;
try testing.expect(!it.is_trailer);
try testing.expectEqualStrings("connection", header.name);
try testing.expectEqualStrings("keep-alive", header.value);
}
{
const header = it.next().?;
try testing.expect(!it.is_trailer);
try testing.expectEqualStrings("transfer-encoding", header.name);
try testing.expectEqualStrings("chunked", header.value);
}
{
const header = it.next().?;
try testing.expect(it.is_trailer);
try testing.expectEqualStrings("X-Checksum", header.name);
try testing.expectEqualStrings("aaaa", header.value);
}
try testing.expectEqual(null, it.next());
}
// connection has been kept alive
try std.testing.expect(client.connection_pool.free_len == 1);
try testing.expect(client.connection_pool.free_len == 1);
}
fn serverThread(http_server: *std.http.Server) anyerror!void {
const gpa = std.testing.allocator;
const gpa = testing.allocator;
var header_buffer: [1024]u8 = undefined;
var remaining: usize = 1;
@ -60,17 +82,16 @@ fn serverThread(http_server: *std.http.Server) anyerror!void {
};
try serve(&res);
try std.testing.expectEqual(.reset, res.reset());
try testing.expectEqual(.reset, res.reset());
}
}
fn serve(res: *std.http.Server.Response) !void {
try std.testing.expectEqualStrings(res.request.target, "/trailer");
try testing.expectEqualStrings(res.request.target, "/trailer");
res.transfer_encoding = .chunked;
try res.send();
try res.writeAll("Hello, ");
try res.writeAll("World!\n");
// try res.finish();
try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
}