std.http: Fix segfault while redirecting

Make to avoid releasing request's connection twice.
Change the `Request.connection` field optional. This field is null while the connection is released.

Fixes #15965
This commit is contained in:
Mizuochi Keita 2023-06-08 01:09:23 +09:00 committed by Andrew Kelley
parent e23d48e61a
commit 729a051e9e
2 changed files with 66 additions and 25 deletions

View File

@ -451,7 +451,8 @@ pub const Response = struct {
pub const Request = struct {
uri: Uri,
client: *Client,
connection: *ConnectionPool.Node,
/// is null when this connection is released
connection: ?*ConnectionPool.Node,
method: http.Method,
version: http.Version = .@"HTTP/1.1",
@ -481,13 +482,14 @@ pub const Request = struct {
req.response.parser.header_bytes.deinit(req.client.allocator);
}
if (!req.response.parser.done) {
// If the response wasn't fully read, then we need to close the connection.
req.connection.data.closing = true;
if (req.connection) |connection| {
if (!req.response.parser.done) {
// If the response wasn't fully read, then we need to close the connection.
connection.data.closing = true;
}
req.client.connection_pool.release(req.client, connection);
}
req.client.connection_pool.release(req.client, req.connection);
req.arena.deinit();
req.* = undefined;
}
@ -504,7 +506,8 @@ pub const Request = struct {
.zstd => |*zstd| zstd.deinit(),
}
req.client.connection_pool.release(req.client, req.connection);
req.client.connection_pool.release(req.client, req.connection.?);
req.connection = null;
const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
@ -534,7 +537,7 @@ pub const Request = struct {
/// Send the request to the server.
pub fn start(req: *Request) StartError!void {
var buffered = std.io.bufferedWriter(req.connection.data.writer());
var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
const w = buffered.writer();
try w.writeAll(@tagName(req.method));
@ -544,7 +547,7 @@ pub const Request = struct {
try w.writeAll(req.uri.host.?);
try w.writeByte(':');
try w.print("{}", .{req.uri.port.?});
} else if (req.connection.data.proxied) {
} else if (req.connection.?.data.proxied) {
// proxied connections require the full uri
try w.print("{+/}", .{req.uri});
} else {
@ -625,7 +628,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
if (amt == 0 and req.response.parser.done) break;
index += amt;
}
@ -643,10 +646,10 @@ pub const Request = struct {
pub fn wait(req: *Request) WaitError!void {
while (true) { // handle redirects
while (true) { // read headers
try req.connection.data.fill();
try req.connection.?.data.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
req.connection.data.drop(@intCast(u16, nchecked));
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
req.connection.?.data.drop(@intCast(u16, nchecked));
if (req.response.parser.state.isContent()) break;
}
@ -654,12 +657,12 @@ pub const Request = struct {
try req.response.parse(req.response.parser.header_bytes.items, false);
if (req.response.status == .switching_protocols) {
req.connection.data.closing = false;
req.connection.?.data.closing = false;
req.response.parser.done = true;
}
if (req.method == .CONNECT and req.response.status == .ok) {
req.connection.data.closing = false;
req.connection.?.data.closing = false;
req.response.parser.done = true;
}
@ -670,9 +673,9 @@ pub const Request = struct {
const res_connection = req.response.headers.getFirstValue("connection");
const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
if (res_keepalive and (req_keepalive or req_connection == null)) {
req.connection.data.closing = false;
req.connection.?.data.closing = false;
} else {
req.connection.data.closing = true;
req.connection.?.data.closing = true;
}
if (req.response.transfer_encoding) |te| {
@ -762,10 +765,10 @@ pub const Request = struct {
const has_trail = !req.response.parser.state.isContent();
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.data.fill();
try req.connection.?.data.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
req.connection.data.drop(@intCast(u16, nchecked));
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
req.connection.?.data.drop(@intCast(u16, nchecked));
}
if (has_trail) {
@ -803,16 +806,16 @@ pub const Request = struct {
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
switch (req.transfer_encoding) {
.chunked => {
try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
try req.connection.data.writeAll(bytes);
try req.connection.data.writeAll("\r\n");
try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
try req.connection.?.data.writeAll(bytes);
try req.connection.?.data.writeAll("\r\n");
return bytes.len;
},
.content_length => |*len| {
if (len.* < bytes.len) return error.MessageTooLong;
const amt = try req.connection.data.write(bytes);
const amt = try req.connection.?.data.write(bytes);
len.* -= amt;
return amt;
},
@ -832,7 +835,7 @@ pub const Request = struct {
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) FinishError!void {
switch (req.transfer_encoding) {
.chunked => try req.connection.data.writeAll("0\r\n\r\n"),
.chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
.none => {},
}

View File

@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void {
try res.writeAll("Hello, ");
try res.writeAll("Redirected!\n");
try res.finish();
} else if (mem.eql(u8, res.request.target, "/redirect/invalid")) {
const invalid_port = try getUnusedTcpPort();
const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port});
defer salloc.free(location);
res.status = .found;
try res.headers.append("location", location);
try res.do();
try res.finish();
} else {
res.status = .not_found;
try res.do();
@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void {
conn.close();
}
fn getUnusedTcpPort() !u16 {
const addr = try std.net.Address.parseIp("127.0.0.1", 0);
var s = std.net.StreamServer.init(.{});
defer s.deinit();
try s.listen(addr);
return s.listen_address.in.getPort();
}
pub fn main() !void {
const log = std.log.scoped(.client);
@ -533,6 +550,27 @@ pub fn main() !void {
// connection has been kept alive
try testing.expect(client.connection_pool.free_len == 1);
{ // check client without segfault by connection error after redirection
var h = http.Headers{ .allocator = calloc };
defer h.deinit();
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port});
defer calloc.free(location);
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
var req = try client.request(.GET, uri, h, .{});
defer req.deinit();
try req.start();
const result = req.wait();
try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error
}
// connection has been kept alive
try testing.expect(client.connection_pool.free_len == 1);
client.deinit();
killServer(server.socket.listen_address);