diff --git a/lib/std/http.zig b/lib/std/http.zig index 53a59f0e83..9377efe190 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -393,12 +393,18 @@ pub const Reader = struct { ReadFailed, }; + pub fn restituteHeadBuffer(reader: *Reader) void { + reader.in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; + } + /// Buffers the entire head into `head_buffer`, invalidating the previous /// `head_buffer`, if any. pub fn receiveHead(reader: *Reader) HeadError!void { reader.trailers = &.{}; const in = reader.in; in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; in.rebase(); var hp: HeadParser = .{}; var head_end: usize = 0; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 851fe583ca..2658628bf9 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -117,14 +117,12 @@ pub const ConnectionPool = struct { /// /// Threadsafe. pub fn release(pool: *ConnectionPool, connection: *Connection) void { - if (connection.closing) return connection.destroy(); - pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(&connection.pool_node); - if (pool.free_size == 0) return connection.destroy(); + if (connection.closing or pool.free_size == 0) return connection.destroy(); if (pool.free_len >= pool.free_size) { const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); @@ -669,8 +667,10 @@ pub const Response = struct { /// See also: /// * `readerDecompressing` pub fn reader(response: *Response) std.io.Reader { + const req = response.request; + if (!req.method.responseHasBody()) return .ending; const head = &response.head; - return response.request.reader.bodyReader(head.transfer_encoding, head.content_length); + return req.reader.bodyReader(head.transfer_encoding, head.content_length); } /// If compressed body has been negotiated this will return decompressed bytes. @@ -805,11 +805,13 @@ pub const Request = struct { /// Returns the request's `Connection` back to the pool of the `Client`. pub fn deinit(r: *Request) void { + r.reader.restituteHeadBuffer(); if (r.connection) |connection| { - if (r.reader.state != .ready) { - // Connection cannot be reused. - connection.closing = true; - } + connection.closing = connection.closing or switch (r.reader.state) { + .ready => false, + .received_head => r.method.requestHasBody(), + else => true, + }; r.client.connection_pool.release(connection); } r.* = undefined; @@ -1025,7 +1027,14 @@ pub const Request = struct { } if (head.status.class() == .redirect and r.redirect_behavior != .unhandled) { - if (r.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + if (r.redirect_behavior == .not_allowed) { + // Connection can still be reused by skipping the body. + var reader = r.reader.bodyReader(head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => connection.closing = true, + }; + return error.TooManyHttpRedirects; + } try r.redirect(head, &aux_buf); try r.sendBodiless(); continue; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 4f8866e532..d66663d5ff 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -30,6 +30,10 @@ pub fn init(in: *std.io.BufferedReader, out: *std.io.BufferedWriter) Server { }; } +pub fn deinit(s: *Server) void { + s.reader.restituteHeadBuffer(); +} + pub const ReceiveHeadError = http.Reader.HeadError || error{ /// Client sent headers that did not conform to the HTTP protocol. /// @@ -483,6 +487,7 @@ pub const Request = struct { return error.HttpExpectationFailed; } } + if (!request.head.method.requestHasBody()) return .ending; return request.server.reader.bodyReader(request.head.transfer_encoding, request.head.content_length); } diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 9be94d4650..f5a68fed73 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -10,7 +10,8 @@ const expectError = std.testing.expectError; test "trailers" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [1024]u8 = undefined; var send_buffer: [1024]u8 = undefined; var remaining: usize = 1; @@ -96,7 +97,8 @@ test "trailers" { test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) !void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [8192]u8 = undefined; var send_buffer: [500]u8 = undefined; const connection = try net_server.accept(); @@ -162,11 +164,12 @@ test "HTTP server handles a chunked transfer coding request" { test "echo content server" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [1024]u8 = undefined; var send_buffer: [100]u8 = undefined; - accept: while (true) { + accept: while (!test_server.shutting_down) { const connection = try net_server.accept(); defer connection.stream.close(); @@ -251,7 +254,8 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { // In this case, the response is expected to stream until the connection is // closed, indicating the end of the body. const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [1000]u8 = undefined; var send_buffer: [500]u8 = undefined; var remaining: usize = 1; @@ -279,6 +283,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { try bw.print("{d}, ah ha ha!\n", .{i}); } try expectEqual(7390, bw.count); + try bw.flush(); try response.end(); try expectEqual(.closing, server.reader.state); } @@ -319,7 +324,8 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { test "receiving arbitrary http headers from the client" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [666]u8 = undefined; var send_buffer: [777]u8 = undefined; var remaining: usize = 1; @@ -385,15 +391,13 @@ test "general client/server API coverage" { return error.SkipZigTest; } - const global = struct { - var handle_new_requests = true; - }; const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [1024]u8 = undefined; var send_buffer: [100]u8 = undefined; - outer: while (global.handle_new_requests) { + outer: while (!test_server.shutting_down) { var connection = try net_server.accept(); defer connection.stream.close(); @@ -544,17 +548,13 @@ test "general client/server API coverage" { return s.listen_address.in.getPort(); } }); - defer { - global.handle_new_requests = false; - test_server.destroy(); - } + defer test_server.destroy(); const log = std.log.scoped(.client); const gpa = std.testing.allocator; var client: http.Client = .{ .allocator = gpa }; - errdefer client.deinit(); - // defer client.deinit(); handled below + defer client.deinit(); const port = test_server.port(); @@ -870,20 +870,12 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - client.deinit(); - - { - global.handle_new_requests = false; - - const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address); - conn.close(); - } } test "Server streams both reading and writing" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [1024]u8 = undefined; var send_buffer: [777]u8 = undefined; @@ -1091,19 +1083,18 @@ fn echoTests(client: *http.Client, port: u16) !void { try expectEqual(.expectation_failed, response.head.status); _ = try response.reader().discardRemaining(); } - - _ = try client.fetch(.{ - .location = .{ - .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}), - }, - }); } const TestServer = struct { + shutting_down: bool, server_thread: std.Thread, net_server: std.net.Server, fn destroy(self: *@This()) void { + self.shutting_down = true; + const conn = std.net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure"); + conn.close(); + self.server_thread.join(); self.net_server.deinit(); std.testing.allocator.destroy(self); @@ -1123,14 +1114,18 @@ fn createTestServer(S: type) !*TestServer { const address = try std.net.Address.parseIp("127.0.0.1", 0); const test_server = try std.testing.allocator.create(TestServer); - test_server.net_server = try address.listen(.{ .reuse_address = true }); - test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server}); + test_server.* = .{ + .net_server = try address.listen(.{ .reuse_address = true }), + .server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}), + .shutting_down = false, + }; return test_server; } test "redirect to different connection" { const test_server_new = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [888]u8 = undefined; var send_buffer: [777]u8 = undefined; @@ -1155,7 +1150,8 @@ test "redirect to different connection" { global.other_port = test_server_new.port(); const test_server_orig = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; var recv_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined;