From 1b3ebfefd8d8fd05152d55c431f741880d1ce2a7 Mon Sep 17 00:00:00 2001 From: Nameless Date: Wed, 3 May 2023 14:34:10 -0500 Subject: [PATCH] fix keepalive and large buffered writes --- lib/std/http/Client.zig | 22 ++++---- lib/std/http/Server.zig | 23 ++++---- test/standalone/http.zig | 112 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 129 insertions(+), 28 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 72f16b0e76..77f4c6c6c3 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -71,7 +71,7 @@ pub const ConnectionPool = struct { while (next) |node| : (next = node.prev) { if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue; if (node.data.port != criteria.port) continue; - if (mem.eql(u8, node.data.host, criteria.host)) continue; + if (!mem.eql(u8, node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); return node; @@ -317,32 +317,29 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); try bconn.conn.writeAll(buffer); } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); return buffer.len; } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); return try bconn.conn.write(buffer); } } pub fn flush(bconn: *BufferedConnection) WriteError!void { + defer bconn.write_end = 0; return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } @@ -720,12 +717,13 @@ pub const Request = struct { req.response.parser.done = true; } + // we default to using keep-alive if not provided const req_connection = req.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - if (req_keepalive and res_keepalive) { + if (res_keepalive and (req_keepalive or req_connection == null)) { req.connection.data.closing = false; } else { req.connection.data.closing = true; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index d473366092..2d3032bcdb 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -161,32 +161,29 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); try bconn.conn.writeAll(buffer); } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); return buffer.len; } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); return try bconn.conn.write(buffer); } } pub fn flush(bconn: *BufferedConnection) WriteError!void { + defer bconn.write_end = 0; return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } @@ -397,12 +394,14 @@ pub const Response = struct { // A connection is only keep-alive if the Connection header is present and it's value is not "close". // The server and client must both agree + // + // do() defaults to using keep-alive if the client requests it. const res_connection = res.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); const req_connection = res.request.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (res_keepalive and req_keepalive) { + if (req_keepalive and (res_keepalive or res_connection == null)) { res.connection.conn.closing = false; } else { res.connection.conn.closing = true; @@ -424,7 +423,7 @@ pub const Response = struct { res.headers.clearRetainingCapacity(); - res.request.headers.clearRetainingCapacity(); + res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here res.request.parser.reset(); res.request = Request{ diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 180e863bba..13dc278b6d 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -9,8 +9,8 @@ const testing = std.testing; const max_header_size = 8192; -var gpa_server = std.heap.GeneralPurposeAllocator(.{}){}; -var gpa_client = std.heap.GeneralPurposeAllocator(.{}){}; +var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; +var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; const salloc = gpa_server.allocator(); const calloc = gpa_client.allocator(); @@ -44,6 +44,24 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); try res.finish(); } + } else if (mem.startsWith(u8, res.request.target, "/large")) { + res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 }; + + try res.do(); + + var i: u32 = 0; + while (i < 5) : (i += 1) { + try res.writeAll("Hello, World!\n"); + } + + try res.writeAll("Hello, World!\n" ** 1024); + + i = 0; + while (i < 5) : (i += 1) { + try res.writeAll("Hello, World!\n"); + } + + try res.finish(); } else if (mem.eql(u8, res.request.target, "/echo-content")) { try testing.expectEqualStrings("Hello, World!\n", body); try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); @@ -68,6 +86,7 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); // try res.finish(); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); + try res.connection.flush(); } else if (mem.eql(u8, res.request.target, "/redirect/1")) { res.transfer_encoding = .chunked; @@ -177,8 +196,7 @@ pub fn main() !void { const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); var client = Client{ .allocator = calloc }; - - defer client.deinit(); + // defer client.deinit(); handled below { // read content-length response var h = http.Headers{ .allocator = calloc }; @@ -202,6 +220,33 @@ pub fn main() !void { try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + { // read large content-length response + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/large", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192 * 1024); + defer calloc.free(body); + + try testing.expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); + } + + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send head request and not read chunked var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -225,6 +270,9 @@ pub fn main() !void { try testing.expectEqualStrings("14", req.response.headers.getFirstValue("content-length").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // read chunked response var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -247,6 +295,9 @@ pub fn main() !void { try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send head request and not read chunked var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -270,6 +321,9 @@ pub fn main() !void { try testing.expectEqualStrings("chunked", req.response.headers.getFirstValue("transfer-encoding").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // check trailing headers var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -292,6 +346,9 @@ pub fn main() !void { try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send content-length request var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -321,6 +378,36 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + { // read content-length response with connection close + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("connection", "close"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + } + + // connection has been closed + try testing.expect(client.connection_pool.free_len == 0); + { // send chunked request var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -350,6 +437,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // relative redirect var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -371,6 +461,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // redirect from root var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -392,6 +485,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // absolute redirect var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -413,6 +509,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // too many redirects var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -432,6 +531,11 @@ pub fn main() !void { }; } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + client.deinit(); + killServer(server.socket.listen_address); server_thread.join(); }