From 729a051e9e38674233190aea23c0ac8c134f2d67 Mon Sep 17 00:00:00 2001 From: Mizuochi Keita Date: Thu, 8 Jun 2023 01:09:23 +0900 Subject: [PATCH] std.http: Fix segfault while redirecting Make to avoid releasing request's connection twice. Change the `Request.connection` field optional. This field is null while the connection is released. Fixes #15965 --- lib/std/http/Client.zig | 53 +++++++++++++++++++++------------------- test/standalone/http.zig | 38 ++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 25 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1085309cbb..975375a2b9 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -451,7 +451,8 @@ pub const Response = struct { pub const Request = struct { uri: Uri, client: *Client, - connection: *ConnectionPool.Node, + /// is null when this connection is released + connection: ?*ConnectionPool.Node, method: http.Method, version: http.Version = .@"HTTP/1.1", @@ -481,13 +482,14 @@ pub const Request = struct { req.response.parser.header_bytes.deinit(req.client.allocator); } - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - req.connection.data.closing = true; + if (req.connection) |connection| { + if (!req.response.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + connection.data.closing = true; + } + req.client.connection_pool.release(req.client, connection); } - req.client.connection_pool.release(req.client, req.connection); - req.arena.deinit(); req.* = undefined; } @@ -504,7 +506,8 @@ pub const Request = struct { .zstd => |*zstd| zstd.deinit(), } - req.client.connection_pool.release(req.client, req.connection); + req.client.connection_pool.release(req.client, req.connection.?); + req.connection = null; const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; @@ -534,7 +537,7 @@ pub const Request = struct { /// Send the request to the server. pub fn start(req: *Request) StartError!void { - var buffered = std.io.bufferedWriter(req.connection.data.writer()); + var buffered = std.io.bufferedWriter(req.connection.?.data.writer()); const w = buffered.writer(); try w.writeAll(@tagName(req.method)); @@ -544,7 +547,7 @@ pub const Request = struct { try w.writeAll(req.uri.host.?); try w.writeByte(':'); try w.print("{}", .{req.uri.port.?}); - } else if (req.connection.data.proxied) { + } else if (req.connection.?.data.proxied) { // proxied connections require the full uri try w.print("{+/}", .{req.uri}); } else { @@ -625,7 +628,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip); + const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip); if (amt == 0 and req.response.parser.done) break; index += amt; } @@ -643,10 +646,10 @@ pub const Request = struct { pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers - try req.connection.data.fill(); + try req.connection.?.data.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek()); - req.connection.data.drop(@intCast(u16, nchecked)); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); + req.connection.?.data.drop(@intCast(u16, nchecked)); if (req.response.parser.state.isContent()) break; } @@ -654,12 +657,12 @@ pub const Request = struct { try req.response.parse(req.response.parser.header_bytes.items, false); if (req.response.status == .switching_protocols) { - req.connection.data.closing = false; + req.connection.?.data.closing = false; req.response.parser.done = true; } if (req.method == .CONNECT and req.response.status == .ok) { - req.connection.data.closing = false; + req.connection.?.data.closing = false; req.response.parser.done = true; } @@ -670,9 +673,9 @@ pub const Request = struct { const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); if (res_keepalive and (req_keepalive or req_connection == null)) { - req.connection.data.closing = false; + req.connection.?.data.closing = false; } else { - req.connection.data.closing = true; + req.connection.?.data.closing = true; } if (req.response.transfer_encoding) |te| { @@ -762,10 +765,10 @@ pub const Request = struct { const has_trail = !req.response.parser.state.isContent(); while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.data.fill(); + try req.connection.?.data.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek()); - req.connection.data.drop(@intCast(u16, nchecked)); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); + req.connection.?.data.drop(@intCast(u16, nchecked)); } if (has_trail) { @@ -803,16 +806,16 @@ pub const Request = struct { pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { - try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.data.writeAll(bytes); - try req.connection.data.writeAll("\r\n"); + try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.data.writeAll(bytes); + try req.connection.?.data.writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.data.write(bytes); + const amt = try req.connection.?.data.write(bytes); len.* -= amt; return amt; }, @@ -832,7 +835,7 @@ pub const Request = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => try req.connection.data.writeAll("0\r\n\r\n"), + .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } diff --git a/test/standalone/http.zig b/test/standalone/http.zig index ffb7a59276..90ba285105 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("Hello, "); try res.writeAll("Redirected!\n"); try res.finish(); + } else if (mem.eql(u8, res.request.target, "/redirect/invalid")) { + const invalid_port = try getUnusedTcpPort(); + const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port}); + defer salloc.free(location); + + res.status = .found; + try res.headers.append("location", location); + try res.do(); + try res.finish(); } else { res.status = .not_found; try res.do(); @@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void { conn.close(); } +fn getUnusedTcpPort() !u16 { + const addr = try std.net.Address.parseIp("127.0.0.1", 0); + var s = std.net.StreamServer.init(.{}); + defer s.deinit(); + try s.listen(addr); + return s.listen_address.in.getPort(); +} + pub fn main() !void { const log = std.log.scoped(.client); @@ -533,6 +550,27 @@ pub fn main() !void { // connection has been kept alive try testing.expect(client.connection_pool.free_len == 1); + { // check client without segfault by connection error after redirection + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + const result = req.wait(); + + try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error + } + + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + client.deinit(); killServer(server.socket.listen_address);