diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 2064d767ba..00cd136307 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -428,12 +428,14 @@ pub const Response = struct { CompressionUnsupported, }; - pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void { - var it = mem.tokenizeAny(u8, bytes, "\r\n"); + pub fn parse(res: *Response, bytes: []const u8) ParseError!void { + var it = mem.splitSequence(u8, bytes, "\r\n"); - const first_line = it.next() orelse return error.HttpHeadersInvalid; - if (first_line.len < 12) + const first_line = it.next().?; + if (first_line.len < 12) { + std.debug.print("first line: '{s}'\n", .{first_line}); return error.HttpHeadersInvalid; + } const version: http.Version = switch (int64(first_line[0..8])) { int64("HTTP/1.0") => .@"HTTP/1.0", @@ -449,17 +451,16 @@ pub const Response = struct { res.reason = reason; while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; + if (line.len == 0) return; switch (line[0]) { ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, else => {}, } - var line_it = mem.tokenizeAny(u8, line, ": "); - const header_name = line_it.next() orelse return error.HttpHeadersInvalid; + var line_it = mem.splitSequence(u8, line, ": "); + const header_name = line_it.next().?; const header_value = line_it.rest(); - - if (trailing) continue; + if (header_value.len == 0) return error.HttpHeadersInvalid; if (std.ascii.eqlIgnoreCase(header_name, "connection")) { res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); @@ -538,6 +539,10 @@ pub const Response = struct { try expectEqual(@as(u10, 999), parseInt3("999")); } + pub fn iterateHeaders(r: Response) proto.HeaderIterator { + return proto.HeaderIterator.init(r.parser.get()); + } + version: http.Version, status: http.Status, reason: []const u8, @@ -868,7 +873,7 @@ pub const Request = struct { if (req.response.parser.state.isContent()) break; } - try req.response.parse(req.response.parser.get(), false); + try req.response.parse(req.response.parser.get()); if (req.response.status == .@"continue") { // We're done parsing the continue response; reset to prepare @@ -903,21 +908,21 @@ pub const Request = struct { return; // The response is empty; no further setup or redirection is necessary. } - if (req.response.transfer_encoding != .none) { - switch (req.response.transfer_encoding) { - .none => unreachable, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - } else if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, } if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { @@ -1014,27 +1019,16 @@ pub const Request = struct { //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, else => try req.transferRead(buffer), }; + if (out_index > 0) return out_index; - if (out_index == 0) { - const has_trail = !req.response.parser.state.isContent(); + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.?.fill(); - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - } - - if (has_trail) { - // The response headers before the trailers are already - // guaranteed to be valid, so they will always be parsed again - // and cannot return an error. - // This will *only* fail for a malformed trailer. - req.response.parse(req.response.parser.get(), true) catch return error.InvalidTrailers; - } + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); } - return out_index; + return 0; } /// Reads data from the response body. Must be called after `wait`. diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 62016e408d..4c69a79105 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -570,9 +570,10 @@ pub const HeadersParser = struct { .chunk_data => if (r.next_chunk_length == 0) { if (std.mem.eql(u8, conn.peek(), "\r\n")) { r.state = .finished; - r.done = true; + conn.drop(2); } else { - // The trailer section is formatted identically to the header section. + // The trailer section is formatted identically + // to the header section. r.state = .seen_rn; } r.done = true; @@ -613,6 +614,68 @@ pub const HeadersParser = struct { } }; +pub const HeaderIterator = struct { + bytes: []const u8, + index: usize, + is_trailer: bool, + + pub fn init(bytes: []const u8) HeaderIterator { + return .{ + .bytes = bytes, + .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2, + .is_trailer = false, + }; + } + + pub fn next(it: *HeaderIterator) ?std.http.Header { + const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?; + var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": "); + const name = kv_it.next().?; + const value = kv_it.rest(); + if (value.len == 0) { + if (it.is_trailer) return null; + const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse + return null; + it.is_trailer = true; + it.index = next_end + 2; + kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": "); + return .{ + .name = kv_it.next().?, + .value = kv_it.rest(), + }; + } + it.index = end + 2; + return .{ + .name = name, + .value = value, + }; + } + + test next { + var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n"); + try std.testing.expect(!it.is_trailer); + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("a", header.name); + try std.testing.expectEqualStrings("b", header.value); + } + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("c", header.name); + try std.testing.expectEqualStrings("d", header.value); + } + { + const header = it.next().?; + try std.testing.expect(it.is_trailer); + try std.testing.expectEqualStrings("e", header.name); + try std.testing.expectEqualStrings("f", header.value); + } + try std.testing.expectEqual(null, it.next()); + } +}; + inline fn int16(array: *const [2]u8) u16 { return @as(u16, @bitCast(array.*)); } diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 9175660e7d..3b3a008922 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -1,7 +1,8 @@ const std = @import("std"); +const testing = std.testing; test "trailers" { - const gpa = std.testing.allocator; + const gpa = testing.allocator; var http_server = std.http.Server.init(.{ .reuse_address = true, @@ -21,28 +22,49 @@ test "trailers" { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); - defer req.deinit(); + { + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); - try req.send(.{}); - try req.wait(); + try req.send(.{}); + try req.wait(); - const body = try req.reader().readAllAlloc(gpa, 8192); - defer gpa.free(body); + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); - try std.testing.expectEqualStrings("Hello, World!\n", body); - if (true) @panic("TODO implement inspecting custom headers in responses"); - //try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?); + try testing.expectEqualStrings("Hello, World!\n", body); + + var it = req.response.iterateHeaders(); + { + const header = it.next().?; + try testing.expect(!it.is_trailer); + try testing.expectEqualStrings("connection", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + } + { + const header = it.next().?; + try testing.expect(!it.is_trailer); + try testing.expectEqualStrings("transfer-encoding", header.name); + try testing.expectEqualStrings("chunked", header.value); + } + { + const header = it.next().?; + try testing.expect(it.is_trailer); + try testing.expectEqualStrings("X-Checksum", header.name); + try testing.expectEqualStrings("aaaa", header.value); + } + try testing.expectEqual(null, it.next()); + } // connection has been kept alive - try std.testing.expect(client.connection_pool.free_len == 1); + try testing.expect(client.connection_pool.free_len == 1); } fn serverThread(http_server: *std.http.Server) anyerror!void { - const gpa = std.testing.allocator; + const gpa = testing.allocator; var header_buffer: [1024]u8 = undefined; var remaining: usize = 1; @@ -60,17 +82,16 @@ fn serverThread(http_server: *std.http.Server) anyerror!void { }; try serve(&res); - try std.testing.expectEqual(.reset, res.reset()); + try testing.expectEqual(.reset, res.reset()); } } fn serve(res: *std.http.Server.Response) !void { - try std.testing.expectEqualStrings(res.request.target, "/trailer"); + try testing.expectEqualStrings(res.request.target, "/trailer"); res.transfer_encoding = .chunked; try res.send(); try res.writeAll("Hello, "); try res.writeAll("World!\n"); - // try res.finish(); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); }