mirror of
https://github.com/ziglang/zig.git
synced 2026-02-21 16:54:52 +00:00
std.http: Fix segfault while redirecting
Make to avoid releasing request's connection twice. Change the `Request.connection` field optional. This field is null while the connection is released. Fixes #15965
This commit is contained in:
parent
e23d48e61a
commit
729a051e9e
@ -451,7 +451,8 @@ pub const Response = struct {
|
||||
pub const Request = struct {
|
||||
uri: Uri,
|
||||
client: *Client,
|
||||
connection: *ConnectionPool.Node,
|
||||
/// is null when this connection is released
|
||||
connection: ?*ConnectionPool.Node,
|
||||
|
||||
method: http.Method,
|
||||
version: http.Version = .@"HTTP/1.1",
|
||||
@ -481,13 +482,14 @@ pub const Request = struct {
|
||||
req.response.parser.header_bytes.deinit(req.client.allocator);
|
||||
}
|
||||
|
||||
if (!req.response.parser.done) {
|
||||
// If the response wasn't fully read, then we need to close the connection.
|
||||
req.connection.data.closing = true;
|
||||
if (req.connection) |connection| {
|
||||
if (!req.response.parser.done) {
|
||||
// If the response wasn't fully read, then we need to close the connection.
|
||||
connection.data.closing = true;
|
||||
}
|
||||
req.client.connection_pool.release(req.client, connection);
|
||||
}
|
||||
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
|
||||
req.arena.deinit();
|
||||
req.* = undefined;
|
||||
}
|
||||
@ -504,7 +506,8 @@ pub const Request = struct {
|
||||
.zstd => |*zstd| zstd.deinit(),
|
||||
}
|
||||
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection.?);
|
||||
req.connection = null;
|
||||
|
||||
const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
|
||||
|
||||
@ -534,7 +537,7 @@ pub const Request = struct {
|
||||
|
||||
/// Send the request to the server.
|
||||
pub fn start(req: *Request) StartError!void {
|
||||
var buffered = std.io.bufferedWriter(req.connection.data.writer());
|
||||
var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
|
||||
const w = buffered.writer();
|
||||
|
||||
try w.writeAll(@tagName(req.method));
|
||||
@ -544,7 +547,7 @@ pub const Request = struct {
|
||||
try w.writeAll(req.uri.host.?);
|
||||
try w.writeByte(':');
|
||||
try w.print("{}", .{req.uri.port.?});
|
||||
} else if (req.connection.data.proxied) {
|
||||
} else if (req.connection.?.data.proxied) {
|
||||
// proxied connections require the full uri
|
||||
try w.print("{+/}", .{req.uri});
|
||||
} else {
|
||||
@ -625,7 +628,7 @@ pub const Request = struct {
|
||||
|
||||
var index: usize = 0;
|
||||
while (index == 0) {
|
||||
const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
|
||||
const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
|
||||
if (amt == 0 and req.response.parser.done) break;
|
||||
index += amt;
|
||||
}
|
||||
@ -643,10 +646,10 @@ pub const Request = struct {
|
||||
pub fn wait(req: *Request) WaitError!void {
|
||||
while (true) { // handle redirects
|
||||
while (true) { // read headers
|
||||
try req.connection.data.fill();
|
||||
try req.connection.?.data.fill();
|
||||
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
|
||||
req.connection.data.drop(@intCast(u16, nchecked));
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
|
||||
req.connection.?.data.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (req.response.parser.state.isContent()) break;
|
||||
}
|
||||
@ -654,12 +657,12 @@ pub const Request = struct {
|
||||
try req.response.parse(req.response.parser.header_bytes.items, false);
|
||||
|
||||
if (req.response.status == .switching_protocols) {
|
||||
req.connection.data.closing = false;
|
||||
req.connection.?.data.closing = false;
|
||||
req.response.parser.done = true;
|
||||
}
|
||||
|
||||
if (req.method == .CONNECT and req.response.status == .ok) {
|
||||
req.connection.data.closing = false;
|
||||
req.connection.?.data.closing = false;
|
||||
req.response.parser.done = true;
|
||||
}
|
||||
|
||||
@ -670,9 +673,9 @@ pub const Request = struct {
|
||||
const res_connection = req.response.headers.getFirstValue("connection");
|
||||
const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
|
||||
if (res_keepalive and (req_keepalive or req_connection == null)) {
|
||||
req.connection.data.closing = false;
|
||||
req.connection.?.data.closing = false;
|
||||
} else {
|
||||
req.connection.data.closing = true;
|
||||
req.connection.?.data.closing = true;
|
||||
}
|
||||
|
||||
if (req.response.transfer_encoding) |te| {
|
||||
@ -762,10 +765,10 @@ pub const Request = struct {
|
||||
const has_trail = !req.response.parser.state.isContent();
|
||||
|
||||
while (!req.response.parser.state.isContent()) { // read trailing headers
|
||||
try req.connection.data.fill();
|
||||
try req.connection.?.data.fill();
|
||||
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
|
||||
req.connection.data.drop(@intCast(u16, nchecked));
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
|
||||
req.connection.?.data.drop(@intCast(u16, nchecked));
|
||||
}
|
||||
|
||||
if (has_trail) {
|
||||
@ -803,16 +806,16 @@ pub const Request = struct {
|
||||
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
|
||||
switch (req.transfer_encoding) {
|
||||
.chunked => {
|
||||
try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.data.writeAll(bytes);
|
||||
try req.connection.data.writeAll("\r\n");
|
||||
try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.?.data.writeAll(bytes);
|
||||
try req.connection.?.data.writeAll("\r\n");
|
||||
|
||||
return bytes.len;
|
||||
},
|
||||
.content_length => |*len| {
|
||||
if (len.* < bytes.len) return error.MessageTooLong;
|
||||
|
||||
const amt = try req.connection.data.write(bytes);
|
||||
const amt = try req.connection.?.data.write(bytes);
|
||||
len.* -= amt;
|
||||
return amt;
|
||||
},
|
||||
@ -832,7 +835,7 @@ pub const Request = struct {
|
||||
/// Finish the body of a request. This notifies the server that you have no more data to send.
|
||||
pub fn finish(req: *Request) FinishError!void {
|
||||
switch (req.transfer_encoding) {
|
||||
.chunked => try req.connection.data.writeAll("0\r\n\r\n"),
|
||||
.chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
|
||||
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
|
||||
.none => {},
|
||||
}
|
||||
|
||||
@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void {
|
||||
try res.writeAll("Hello, ");
|
||||
try res.writeAll("Redirected!\n");
|
||||
try res.finish();
|
||||
} else if (mem.eql(u8, res.request.target, "/redirect/invalid")) {
|
||||
const invalid_port = try getUnusedTcpPort();
|
||||
const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port});
|
||||
defer salloc.free(location);
|
||||
|
||||
res.status = .found;
|
||||
try res.headers.append("location", location);
|
||||
try res.do();
|
||||
try res.finish();
|
||||
} else {
|
||||
res.status = .not_found;
|
||||
try res.do();
|
||||
@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void {
|
||||
conn.close();
|
||||
}
|
||||
|
||||
fn getUnusedTcpPort() !u16 {
|
||||
const addr = try std.net.Address.parseIp("127.0.0.1", 0);
|
||||
var s = std.net.StreamServer.init(.{});
|
||||
defer s.deinit();
|
||||
try s.listen(addr);
|
||||
return s.listen_address.in.getPort();
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
const log = std.log.scoped(.client);
|
||||
|
||||
@ -533,6 +550,27 @@ pub fn main() !void {
|
||||
// connection has been kept alive
|
||||
try testing.expect(client.connection_pool.free_len == 1);
|
||||
|
||||
{ // check client without segfault by connection error after redirection
|
||||
var h = http.Headers{ .allocator = calloc };
|
||||
defer h.deinit();
|
||||
|
||||
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port});
|
||||
defer calloc.free(location);
|
||||
const uri = try std.Uri.parse(location);
|
||||
|
||||
log.info("{s}", .{location});
|
||||
var req = try client.request(.GET, uri, h, .{});
|
||||
defer req.deinit();
|
||||
|
||||
try req.start();
|
||||
const result = req.wait();
|
||||
|
||||
try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error
|
||||
}
|
||||
|
||||
// connection has been kept alive
|
||||
try testing.expect(client.connection_pool.free_len == 1);
|
||||
|
||||
client.deinit();
|
||||
|
||||
killServer(server.socket.listen_address);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user