std.http: curate some Server errors, fix reading chunked bodies

This commit is contained in:
Nameless 2023-04-16 16:26:25 -05:00
parent 134294230a
commit 85221b4e97
No known key found for this signature in database
GPG Key ID: A477BC03CAFCCAF7
3 changed files with 138 additions and 81 deletions

View File

@ -193,7 +193,13 @@ pub const Connection = struct {
};
}
pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure };
pub const ReadError = error{
TlsFailure,
TlsAlert,
ConnectionTimedOut,
ConnectionResetByPeer,
UnexpectedReadFailure,
};
pub const Reader = std.io.Reader(*Connection, ReadError, read);
@ -518,7 +524,10 @@ pub const Request = struct {
req.* = undefined;
}
pub fn start(req: *Request, uri: Uri) !void {
pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
/// Send the request to the server.
pub fn start(req: *Request, uri: Uri) StartError!void {
var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
const w = buffered.writer();
@ -575,7 +584,7 @@ pub const Request = struct {
}
} else {
if (has_content_length) {
const content_length = try std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10);
const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
req.transfer_encoding = .{ .content_length = content_length };
} else if (has_transfer_encoding) {
@ -618,7 +627,7 @@ pub const Request = struct {
return index;
}
pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed };
pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
/// Waits for a response from the server and parses any headers that are sent.
/// This function will block until the final response is received.
@ -739,25 +748,23 @@ pub const Request = struct {
/// Reads data from the response body. Must be called after `do`.
pub fn read(req: *Request, buffer: []u8) ReadError!usize {
while (true) {
const out_index = switch (req.response.compression) {
.deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
else => try req.transferRead(buffer),
};
const out_index = switch (req.response.compression) {
.deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
else => try req.transferRead(buffer),
};
if (out_index == 0) {
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.data.buffered.fill();
if (out_index == 0) {
while (!req.response.parser.state.isContent()) { // read trailing headers
try req.connection.data.buffered.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
req.connection.data.buffered.clear(@intCast(u16, nchecked));
}
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
req.connection.data.buffered.clear(@intCast(u16, nchecked));
}
return out_index;
}
return out_index;
}
/// Reads data from the response body. Must be called after `do`.
@ -800,15 +807,19 @@ pub const Request = struct {
}
}
pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
var index: usize = 0;
while (index < bytes.len) {
index += try write(req, bytes[index..]);
}
}
pub const FinishError = WriteError || error{MessageNotCompleted};
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) FinishError!void {
switch (req.transfer_encoding) {
.chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| {
req.client.last_error = .{ .write = err };
return error.WriteFailed;
},
.chunked => try req.connection.data.conn.writeAll("0\r\n\r\n"),
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
.none => {},
}
@ -923,7 +934,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
}
}
pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
UnsupportedUrlScheme,
UriMissingHost,
@ -998,6 +1009,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
.handle_redirects = options.handle_redirects,
.response = .{
.status = undefined,
.reason = undefined,
.version = undefined,
.headers = undefined,
.parser = switch (options.header_strategy) {
@ -1011,8 +1023,6 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
req.arena = std.heap.ArenaAllocator.init(client.allocator);
try req.start(uri);
return req;
}

View File

@ -23,21 +23,33 @@ pub const Connection = struct {
pub const Protocol = enum { plain };
pub fn read(conn: *Connection, buffer: []u8) !usize {
switch (conn.protocol) {
.plain => return conn.stream.read(buffer),
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.read(buffer),
// .tls => return conn.tls_client.read(conn.stream, buffer),
}
} catch |err| switch (err) {
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
switch (conn.protocol) {
.plain => return conn.stream.readAtLeast(buffer, len),
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.readAtLeast(buffer, len),
// .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
}
} catch |err| switch (err) {
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
pub const ReadError = net.Stream.ReadError;
pub const ReadError = error{
ConnectionTimedOut,
ConnectionResetByPeer,
UnexpectedReadFailure,
};
pub const Reader = std.io.Reader(*Connection, ReadError, read);
@ -45,21 +57,31 @@ pub const Connection = struct {
return Reader{ .context = conn };
}
pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
switch (conn.protocol) {
.plain => return conn.stream.writeAll(buffer),
pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
return switch (conn.protocol) {
.plain => conn.stream.writeAll(buffer),
// .tls => return conn.tls_client.writeAll(conn.stream, buffer),
}
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub fn write(conn: *Connection, buffer: []const u8) !usize {
switch (conn.protocol) {
.plain => return conn.stream.write(buffer),
pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
return switch (conn.protocol) {
.plain => conn.stream.write(buffer),
// .tls => return conn.tls_client.write(conn.stream, buffer),
}
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub const WriteError = net.Stream.WriteError || error{};
pub const WriteError = error{
ConnectionResetByPeer,
UnexpectedWriteFailure,
};
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
@ -155,6 +177,25 @@ pub const BufferedConnection = struct {
}
};
/// The mode of transport for responses.
pub const ResponseTransfer = union(enum) {
content_length: u64,
chunked: void,
none: void,
};
/// The decompressor for request messages.
pub const Compression = union(enum) {
pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
deflate: DeflateDecompressor,
gzip: GzipDecompressor,
zstd: ZstdDecompressor,
none: void,
};
/// A HTTP request originating from a client.
pub const Request = struct {
pub const ParseError = Allocator.Error || error{
@ -165,10 +206,11 @@ pub const Request = struct {
HttpHeaderContinuationsUnsupported,
HttpTransferEncodingUnsupported,
HttpConnectionHeaderUnsupported,
InvalidCharacter,
InvalidContentLength,
CompressionNotSupported,
};
pub fn parse(req: *Request, bytes: []const u8) !void {
pub fn parse(req: *Request, bytes: []const u8) ParseError!void {
var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
const first_line = it.next() orelse return error.HttpHeadersInvalid;
@ -211,7 +253,7 @@ pub const Request = struct {
if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
if (req.content_length != null) return error.HttpHeadersInvalid;
req.content_length = try std.fmt.parseInt(u64, header_value, 10);
req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
} else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
// Transfer-Encoding: second, first
// Transfer-Encoding: deflate, chunked
@ -321,6 +363,8 @@ pub const Response = struct {
}
}
pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
/// Send the response headers.
pub fn do(res: *Response) !void {
var buffered = std.io.bufferedWriter(res.connection.writer());
@ -356,7 +400,7 @@ pub const Response = struct {
}
} else {
if (has_content_length) {
const content_length = try std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10);
const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
res.transfer_encoding = .{ .content_length = content_length };
} else if (has_transfer_encoding) {
@ -386,23 +430,23 @@ pub const Response = struct {
return .{ .context = res };
}
pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
if (res.request.parser.isComplete()) return 0;
fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
if (res.request.parser.done) return 0;
var index: usize = 0;
while (index == 0) {
const amt = try res.request.parser.read(&res.connection, buf[index..], false);
if (amt == 0 and res.request.parser.isComplete()) break;
if (amt == 0 and res.request.parser.done) break;
index += amt;
}
return index;
}
pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
/// Wait for the client to send a complete request head.
pub fn wait(res: *Response) !void {
pub fn wait(res: *Response) WaitError!void {
while (true) {
try res.connection.fill();
@ -445,10 +489,10 @@ pub const Response = struct {
if (res.request.transfer_compression) |tc| switch (tc) {
.compress => return error.CompressionNotSupported,
.deflate => res.request.compression = .{
.deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
.deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
},
.gzip => res.request.compression = .{
.gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
.gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
},
.zstd => res.request.compression = .{
.zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
@ -457,7 +501,7 @@ pub const Response = struct {
}
}
pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;
pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
pub const Reader = std.io.Reader(*Response, ReadError, read);
@ -466,12 +510,23 @@ pub const Response = struct {
}
pub fn read(res: *Response, buffer: []u8) ReadError!usize {
return switch (res.request.compression) {
.deflate => |*deflate| try deflate.read(buffer),
.gzip => |*gzip| try gzip.read(buffer),
.zstd => |*zstd| try zstd.read(buffer),
const out_index = switch (res.request.compression) {
.deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
else => try res.transferRead(buffer),
};
if (out_index == 0) {
while (!res.request.parser.state.isContent()) { // read trailing headers
try res.connection.fill();
const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
res.connection.clear(@intCast(u16, nchecked));
}
}
return out_index;
}
pub fn readAll(res: *Response, buffer: []u8) !usize {
@ -513,9 +568,18 @@ pub const Response = struct {
}
}
pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
var index: usize = 0;
while (index < bytes.len) {
index += try write(req, bytes[index..]);
}
}
pub const FinishError = WriteError || error{MessageNotCompleted};
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(res: *Response) !void {
switch (res.headers.transfer_encoding) {
pub fn finish(res: *Response) FinishError!void {
switch (res.transfer_encoding) {
.chunked => try res.connection.writeAll("0\r\n\r\n"),
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
.none => {},
@ -523,25 +587,6 @@ pub const Response = struct {
}
};
/// The mode of transport for responses.
pub const ResponseTransfer = union(enum) {
content_length: u64,
chunked: void,
none: void,
};
/// The decompressor for request messages.
pub const Compression = union(enum) {
pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
deflate: DeflateDecompressor,
gzip: GzipDecompressor,
zstd: ZstdDecompressor,
none: void,
};
pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
return .{
.allocator = allocator,

View File

@ -485,6 +485,8 @@ fn fetchAndUnpack(
var req = try http_client.request(uri, h, .{ .method = .GET });
defer req.deinit();
try req.start();
try req.do();
if (mem.endsWith(u8, uri.path, ".tar.gz")) {