From aa090a49d94155c4804644377db110f3b13f0500 Mon Sep 17 00:00:00 2001 From: Nameless Date: Tue, 22 Aug 2023 10:05:03 -0500 Subject: [PATCH] std.http: handle expect:100-continue and continue responses --- lib/std/http/Client.zig | 49 ++++++++++++++++++++---- lib/std/http/Server.zig | 80 +++++++++++++++++++++------------------ lib/std/http/protocol.zig | 9 +++-- test/standalone/http.zig | 70 +++++++++++++++++++++++++++++++++- 4 files changed, 161 insertions(+), 47 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 152e6ad2dd..2e1be3da76 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -478,6 +478,7 @@ pub const Request = struct { .zstd => |*zstd| zstd.deinit(), } + req.headers.deinit(); req.response.headers.deinit(); if (req.response.parser.header_bytes_owned) { @@ -667,17 +668,19 @@ pub const Request = struct { try req.response.parse(req.response.parser.header_bytes.items, false); - if (req.response.status == .switching_protocols) { + if (req.response.status == .@"continue") { + req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response + req.response.parser.reset(); + break; + } + + // we're switching protocols, so this connection is no longer doing http + if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) { req.connection.?.data.closing = false; req.response.parser.done = true; } - if (req.method == .CONNECT and req.response.status == .ok) { - req.connection.?.data.closing = false; - req.response.parser.done = true; - } - - // we default to using keep-alive if not provided + // we default to using keep-alive if not provided in the client if the server asks for it const req_connection = req.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); @@ -955,6 +958,38 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: return conn; } +pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{NameTooLong} || std.os.ConnectError; + +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node { + if (client.connection_pool.findConnection(.{ + .host = path, + .port = 0, + .is_tls = false, + })) |node| + return node; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + const stream = try std.net.connectUnixSocket(path); + errdefer stream.close(); + + conn.data = .{ + .stream = stream, + .tls_client = undefined, + .protocol = .plain, + + .host = try client.allocator.dupe(u8, path), + .port = 0, + }; + errdefer client.allocator.free(conn.data.host); + + client.connection_pool.addUsed(conn); + + return conn; +} + // Prevents a dependency loop in request() const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 59d404a81a..4bb11eada6 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -411,48 +411,52 @@ pub const Response = struct { } try w.writeAll("\r\n"); - if (!res.headers.contains("server")) { - try w.writeAll("Server: zig (std.http)\r\n"); - } - - if (!res.headers.contains("connection")) { - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - - if (req_keepalive) { - try w.writeAll("Connection: keep-alive\r\n"); - } else { - try w.writeAll("Connection: close\r\n"); - } - } - - const has_transfer_encoding = res.headers.contains("transfer-encoding"); - const has_content_length = res.headers.contains("content-length"); - - if (!has_transfer_encoding and !has_content_length) { - switch (res.transfer_encoding) { - .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), - .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), - .none => {}, - } + if (res.status == .@"continue") { + res.state = .waited; // we still need to send another request after this } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; + if (!res.headers.contains("server")) { + try w.writeAll("Server: zig (std.http)\r\n"); + } - res.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { - const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; - if (std.mem.eql(u8, transfer_encoding, "chunked")) { - res.transfer_encoding = .chunked; + if (!res.headers.contains("connection")) { + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + + if (req_keepalive) { + try w.writeAll("Connection: keep-alive\r\n"); } else { - return error.UnsupportedTransferEncoding; + try w.writeAll("Connection: close\r\n"); + } + } + + const has_transfer_encoding = res.headers.contains("transfer-encoding"); + const has_content_length = res.headers.contains("content-length"); + + if (!has_transfer_encoding and !has_content_length) { + switch (res.transfer_encoding) { + .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), + .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), + .none => {}, } } else { - res.transfer_encoding = .none; - } - } + if (has_content_length) { + const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - try w.print("{}", .{res.headers}); + res.transfer_encoding = .{ .content_length = content_length }; + } else if (has_transfer_encoding) { + const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; + if (std.mem.eql(u8, transfer_encoding, "chunked")) { + res.transfer_encoding = .chunked; + } else { + return error.UnsupportedTransferEncoding; + } + } else { + res.transfer_encoding = .none; + } + } + + try w.print("{}", .{res.headers}); + } try w.writeAll("\r\n"); @@ -516,6 +520,10 @@ pub const Response = struct { res.request.parser.done = true; } + if (res.request.method == .HEAD) { + res.request.parser.done = true; + } + if (!res.request.parser.done) { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 05b8f51b5e..b0dfc73027 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -534,9 +534,9 @@ pub const HeadersParser = struct { if (r.next_chunk_length == 0) r.done = true; - return 0; - } else { - const out_avail = buffer.len; + return out_index; + } else if (out_index < buffer.len) { + const out_avail = buffer.len - out_index; const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); const nread = try conn.read(buffer[0..can_read]); @@ -545,6 +545,8 @@ pub const HeadersParser = struct { if (r.next_chunk_length == 0) r.done = true; return nread; + } else { + return out_index; } }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { @@ -558,6 +560,7 @@ pub const HeadersParser = struct { .chunk_data => if (r.next_chunk_length == 0) { if (std.mem.eql(u8, conn.peek(), "\r\n")) { r.state = .finished; + r.done = true; } else { // The trailer section is formatted identically to the header section. r.state = .seen_rn; diff --git a/test/standalone/http.zig b/test/standalone/http.zig index ca9f8f258e..5a59544a14 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -22,6 +22,18 @@ fn handleRequest(res: *Server.Response) !void { log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target }); + if (res.request.headers.contains("expect")) { + if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) { + res.status = .@"continue"; + try res.do(); + res.status = .ok; + } else { + res.status = .expectation_failed; + try res.do(); + return; + } + } + const body = try res.reader().readAllAlloc(salloc, 8192); defer salloc.free(body); @@ -62,7 +74,7 @@ fn handleRequest(res: *Server.Response) !void { } try res.finish(); - } else if (mem.eql(u8, res.request.target, "/echo-content")) { + } else if (mem.startsWith(u8, res.request.target, "/echo-content")) { try testing.expectEqualStrings("Hello, World!\n", body); try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); @@ -592,6 +604,62 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", res.body.?); } + { // expect: 100-continue + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("expect", "100-continue"); + try h.append("content-type", "text/plain"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.start(); + try req.wait(); + try testing.expectEqual(http.Status.@"continue", req.response.status); + + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + try testing.expectEqual(http.Status.ok, req.response.status); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // expect: garbage + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("content-type", "text/plain"); + try h.append("expect", "garbage"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.start(); + try req.wait(); + try testing.expectEqual(http.Status.expectation_failed, req.response.status); + } + { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); defer calloc.free(location);