std.http: handle expect:100-continue and continue responses

Nameless 2023-08-22 10:05:03 -05:00
parent 5d40338f21
commit aa090a49d9
4 changed files with 161 additions and 47 deletions
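For context, the handshake this commit implements on the client side: the request headers are sent with "Expect: 100-continue", the client waits for the server's interim "100 Continue" status, and only then sends the body and waits again for the final response. A minimal sketch, condensed from the test added at the end of this commit (the function name and the ExpectationRejected error are illustrative, not part of the diff):

const std = @import("std");
const http = std.http;

fn postWithExpectContinue(client: *http.Client, uri: std.Uri, headers: http.Headers) !void {
    // `headers` is assumed to already contain "expect: 100-continue".
    var req = try client.request(.POST, uri, headers, .{});
    defer req.deinit();
    req.transfer_encoding = .chunked;

    try req.start();
    try req.wait(); // first wait: parses the interim response, then the parser is reset

    if (req.response.status != .@"continue") return error.ExpectationRejected;

    try req.writeAll("Hello, World!\n"); // the server agreed, so send the body
    try req.finish();
    try req.wait(); // second wait: parses the real response (e.g. 200 OK)
}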


@@ -478,6 +478,7 @@ pub const Request = struct {
.zstd => |*zstd| zstd.deinit(),
}
req.headers.deinit();
req.response.headers.deinit();
if (req.response.parser.header_bytes_owned) {
@@ -667,17 +668,19 @@ pub const Request = struct {
try req.response.parse(req.response.parser.header_bytes.items, false);
if (req.response.status == .switching_protocols) {
if (req.response.status == .@"continue") {
req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response
req.response.parser.reset();
break;
}
// we're switching protocols, so this connection is no longer doing http
if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) {
req.connection.?.data.closing = false;
req.response.parser.done = true;
}
if (req.method == .CONNECT and req.response.status == .ok) {
req.connection.?.data.closing = false;
req.response.parser.done = true;
}
// we default to using keep-alive if not provided
// we default to keep-alive if the client did not provide a connection header, as long as the server asks for it
const req_connection = req.headers.getFirstValue("connection");
const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
@@ -955,6 +958,38 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
return conn;
}
pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{NameTooLong} || std.os.ConnectError;
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node {
if (client.connection_pool.findConnection(.{
.host = path,
.port = 0,
.is_tls = false,
})) |node|
return node;
const conn = try client.allocator.create(ConnectionPool.Node);
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
const stream = try std.net.connectUnixSocket(path);
errdefer stream.close();
conn.data = .{
.stream = stream,
.tls_client = undefined,
.protocol = .plain,
.host = try client.allocator.dupe(u8, path),
.port = 0,
};
errdefer client.allocator.free(conn.data.host);
client.connection_pool.addUsed(conn);
return conn;
}
// Prevents a dependency loop in request()
const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
pub const ConnectError = ConnectErrorPartial || RequestError;
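connectUnix is a new public entry point, and the diff never shows a caller. A minimal usage sketch under the assumptions visible above: the pool key is the socket path with port 0 and no TLS, and node.data.stream is the std.net.Stream returned by std.net.connectUnixSocket. The socket path and request bytes below are purely illustrative:

const std = @import("std");

fn unixSocketDemo(client: *std.http.Client) !void {
    // Reuses a pooled connection if one already exists for this path;
    // otherwise the new connection is registered in the pool via addUsed().
    const node = try client.connectUnix("/run/example/http.sock");

    // The underlying std.net.Stream can be driven directly.
    try node.data.stream.writer().writeAll("GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n");

    var buf: [4096]u8 = undefined;
    const n = try node.data.stream.reader().read(&buf);
    std.debug.print("{s}", .{buf[0..n]});
}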


@@ -411,48 +411,52 @@ pub const Response = struct {
}
try w.writeAll("\r\n");
if (!res.headers.contains("server")) {
try w.writeAll("Server: zig (std.http)\r\n");
}
if (!res.headers.contains("connection")) {
const req_connection = res.request.headers.getFirstValue("connection");
const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
if (req_keepalive) {
try w.writeAll("Connection: keep-alive\r\n");
} else {
try w.writeAll("Connection: close\r\n");
}
}
const has_transfer_encoding = res.headers.contains("transfer-encoding");
const has_content_length = res.headers.contains("content-length");
if (!has_transfer_encoding and !has_content_length) {
switch (res.transfer_encoding) {
.chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
.content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
.none => {},
}
if (res.status == .@"continue") {
res.state = .waited; // this was only the interim response; the real response still needs to be sent after this
} else {
if (has_content_length) {
const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
if (!res.headers.contains("server")) {
try w.writeAll("Server: zig (std.http)\r\n");
}
res.transfer_encoding = .{ .content_length = content_length };
} else if (has_transfer_encoding) {
const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?;
if (std.mem.eql(u8, transfer_encoding, "chunked")) {
res.transfer_encoding = .chunked;
if (!res.headers.contains("connection")) {
const req_connection = res.request.headers.getFirstValue("connection");
const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
if (req_keepalive) {
try w.writeAll("Connection: keep-alive\r\n");
} else {
return error.UnsupportedTransferEncoding;
try w.writeAll("Connection: close\r\n");
}
}
const has_transfer_encoding = res.headers.contains("transfer-encoding");
const has_content_length = res.headers.contains("content-length");
if (!has_transfer_encoding and !has_content_length) {
switch (res.transfer_encoding) {
.chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
.content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
.none => {},
}
} else {
res.transfer_encoding = .none;
}
}
if (has_content_length) {
const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
try w.print("{}", .{res.headers});
res.transfer_encoding = .{ .content_length = content_length };
} else if (has_transfer_encoding) {
const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?;
if (std.mem.eql(u8, transfer_encoding, "chunked")) {
res.transfer_encoding = .chunked;
} else {
return error.UnsupportedTransferEncoding;
}
} else {
res.transfer_encoding = .none;
}
}
try w.print("{}", .{res.headers});
}
try w.writeAll("\r\n");
@@ -516,6 +520,10 @@ pub const Response = struct {
res.request.parser.done = true;
}
if (res.request.method == .HEAD) {
res.request.parser.done = true;
}
if (!res.request.parser.done) {
if (res.request.transfer_compression) |tc| switch (tc) {
.compress => return error.CompressionNotSupported,
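Taken together, the changes to do() above let a handler answer an "Expect: 100-continue" request with two do() calls: the first, with res.status set to .@"continue", writes only the interim status line and puts the response back into the .waited state; the second writes the real response headers. A sketch of the resulting handler pattern, assembled from the test-server changes below (the echo body handling is illustrative):

const std = @import("std");

fn handle(allocator: std.mem.Allocator, res: *std.http.Server.Response) !void {
    if (res.request.headers.contains("expect")) {
        if (std.mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) {
            res.status = .@"continue";
            try res.do(); // writes "HTTP/1.1 100 Continue\r\n\r\n" and returns to .waited
            res.status = .ok;
        } else {
            res.status = .expectation_failed;
            try res.do();
            return;
        }
    }

    // The client only sends the body after it has seen 100 Continue.
    const body = try res.reader().readAllAlloc(allocator, 8192);
    defer allocator.free(body);

    res.transfer_encoding = .{ .content_length = body.len };
    try res.do(); // second do(): the real response
    try res.writeAll(body);
    try res.finish();
}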


@@ -534,9 +534,9 @@ pub const HeadersParser = struct {
if (r.next_chunk_length == 0) r.done = true;
return 0;
} else {
const out_avail = buffer.len;
return out_index;
} else if (out_index < buffer.len) {
const out_avail = buffer.len - out_index;
const can_read = @as(usize, @intCast(@min(data_avail, out_avail)));
const nread = try conn.read(buffer[0..can_read]);
@@ -545,6 +545,8 @@ pub const HeadersParser = struct {
if (r.next_chunk_length == 0) r.done = true;
return nread;
} else {
return out_index;
}
},
.chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
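From the caller's side, the change above means read() now stops and returns the bytes it has already copied once the supplied buffer is full, rather than reading further from the connection; callers simply keep reading until 0 signals that the body is done. A generic drain loop (res is assumed to be a std.http.Server.Response, as in the test server below):

const std = @import("std");

fn drainBody(res: *std.http.Server.Response) !usize {
    var total: usize = 0;
    var buf: [64]u8 = undefined;
    while (true) {
        const n = try res.reader().read(&buf);
        if (n == 0) break; // the parser reported the body as fully consumed
        total += n;
    }
    return total;
}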
@@ -558,6 +560,7 @@ pub const HeadersParser = struct {
.chunk_data => if (r.next_chunk_length == 0) {
if (std.mem.eql(u8, conn.peek(), "\r\n")) {
r.state = .finished;
r.done = true;
} else {
// The trailer section is formatted identically to the header section.
r.state = .seen_rn;
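The chunk_data change above distinguishes the two ways a chunked body can end after the zero-size chunk: a bare CRLF, which now marks the parser done immediately, or a trailer section, which is parsed like a header block via the .seen_rn state. For illustration, the raw bytes of both cases ("Hello, " and "World!\n" are two 7-byte chunks; the trailer field is made up):

// Terminated by a bare CRLF right after the zero-size chunk.
const no_trailer = "7\r\nHello, \r\n7\r\nWorld!\n\r\n0\r\n\r\n";

// Terminated by a trailer section, formatted identically to the header section.
const with_trailer = "7\r\nHello, \r\n7\r\nWorld!\n\r\n0\r\nX-Checksum: abc123\r\n\r\n";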


@@ -22,6 +22,18 @@ fn handleRequest(res: *Server.Response) !void {
log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target });
if (res.request.headers.contains("expect")) {
if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) {
res.status = .@"continue";
try res.do();
res.status = .ok;
} else {
res.status = .expectation_failed;
try res.do();
return;
}
}
const body = try res.reader().readAllAlloc(salloc, 8192);
defer salloc.free(body);
@@ -62,7 +74,7 @@ fn handleRequest(res: *Server.Response) !void {
}
try res.finish();
} else if (mem.eql(u8, res.request.target, "/echo-content")) {
} else if (mem.startsWith(u8, res.request.target, "/echo-content")) {
try testing.expectEqualStrings("Hello, World!\n", body);
try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?);
@@ -592,6 +604,62 @@ pub fn main() !void {
try testing.expectEqualStrings("Hello, World!\n", res.body.?);
}
{ // expect: 100-continue
var h = http.Headers{ .allocator = calloc };
defer h.deinit();
try h.append("expect", "100-continue");
try h.append("content-type", "text/plain");
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port});
defer calloc.free(location);
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
var req = try client.request(.POST, uri, h, .{});
defer req.deinit();
req.transfer_encoding = .chunked;
try req.start();
try req.wait();
try testing.expectEqual(http.Status.@"continue", req.response.status);
try req.writeAll("Hello, ");
try req.writeAll("World!\n");
try req.finish();
try req.wait();
try testing.expectEqual(http.Status.ok, req.response.status);
const body = try req.reader().readAllAlloc(calloc, 8192);
defer calloc.free(body);
try testing.expectEqualStrings("Hello, World!\n", body);
}
{ // expect: garbage
var h = http.Headers{ .allocator = calloc };
defer h.deinit();
try h.append("content-type", "text/plain");
try h.append("expect", "garbage");
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port});
defer calloc.free(location);
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
var req = try client.request(.POST, uri, h, .{});
defer req.deinit();
req.transfer_encoding = .chunked;
try req.start();
try req.wait();
try testing.expectEqual(http.Status.expectation_failed, req.response.status);
}
{ // issue 16282 *** This test leaves the client in an invalid state; it must run last ***
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
defer calloc.free(location);