mirror of https://github.com/ziglang/zig.git
std.http: curate some Server errors, fix reading chunked bodies
commit 85221b4e97 (parent 134294230a)
@@ -193,7 +193,13 @@ pub const Connection = struct {
         };
     }
 
-    pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure };
+    pub const ReadError = error{
+        TlsFailure,
+        TlsAlert,
+        ConnectionTimedOut,
+        ConnectionResetByPeer,
+        UnexpectedReadFailure,
+    };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
 
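
A minimal sketch (not part of the commit; all names are hypothetical) of what a closed, curated error set like the new Connection.ReadError buys call sites: the switch over the error can be exhaustive, so the compiler flags any case that is added later.

    const std = @import("std");

    // Stand-in for the curated set introduced above.
    const ReadError = error{
        TlsFailure,
        TlsAlert,
        ConnectionTimedOut,
        ConnectionResetByPeer,
        UnexpectedReadFailure,
    };

    // Callers can handle every case explicitly; no `else` arm is needed,
    // so nothing unexpected can slip through unnoticed.
    fn bytesOrZero(result: ReadError!usize) usize {
        return result catch |err| switch (err) {
            error.TlsFailure, error.TlsAlert => 0,
            error.ConnectionTimedOut, error.ConnectionResetByPeer => 0,
            error.UnexpectedReadFailure => 0,
        };
    }

    test "exhaustive handling of a curated error set" {
        try std.testing.expectEqual(@as(usize, 5), bytesOrZero(5));
        try std.testing.expectEqual(@as(usize, 0), bytesOrZero(error.TlsAlert));
    }
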
@@ -518,7 +524,10 @@ pub const Request = struct {
         req.* = undefined;
     }
 
-    pub fn start(req: *Request, uri: Uri) !void {
+    pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+
+    /// Send the request to the server.
+    pub fn start(req: *Request, uri: Uri) StartError!void {
        var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
        const w = buffered.writer();
 
@@ -575,7 +584,7 @@ pub const Request = struct {
             }
         } else {
             if (has_content_length) {
-                const content_length = try std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10);
+                const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
 
                 req.transfer_encoding = .{ .content_length = content_length };
             } else if (has_transfer_encoding) {
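
A minimal sketch (hypothetical names, not from this commit) of the pattern in the two hunks above: once start() declares a concrete error set, lower-level failures such as std.fmt.ParseIntError have to be mapped into it explicitly, here with `catch return error.InvalidContentLength`.

    const std = @import("std");

    // Hypothetical narrowed error set, mirroring StartError above.
    const StartError = error{InvalidContentLength};

    fn parseContentLength(value: []const u8) StartError!u64 {
        // Any parse failure is surfaced as one domain-specific error.
        return std.fmt.parseInt(u64, value, 10) catch return error.InvalidContentLength;
    }

    test "parse failures surface as InvalidContentLength" {
        try std.testing.expectEqual(@as(u64, 42), try parseContentLength("42"));
        try std.testing.expectError(error.InvalidContentLength, parseContentLength("forty-two"));
    }
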
@@ -618,7 +627,7 @@ pub const Request = struct {
         return index;
     }
 
-    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed };
+    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
 
     /// Waits for a response from the server and parses any headers that are sent.
     /// This function will block until the final response is received.
@@ -739,25 +748,23 @@ pub const Request = struct {
 
     /// Reads data from the response body. Must be called after `do`.
     pub fn read(req: *Request, buffer: []u8) ReadError!usize {
-        while (true) {
-            const out_index = switch (req.response.compression) {
-                .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
-                .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
-                .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
-                else => try req.transferRead(buffer),
-            };
+        const out_index = switch (req.response.compression) {
+            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
+            else => try req.transferRead(buffer),
+        };
 
-            if (out_index == 0) {
-                while (!req.response.parser.state.isContent()) { // read trailing headers
-                    try req.connection.data.buffered.fill();
+        if (out_index == 0) {
+            while (!req.response.parser.state.isContent()) { // read trailing headers
+                try req.connection.data.buffered.fill();
 
-                    const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
-                    req.connection.data.buffered.clear(@intCast(u16, nchecked));
-                }
-            }
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+                req.connection.data.buffered.clear(@intCast(u16, nchecked));
+            }
+        }
 
-            return out_index;
-        }
+        return out_index;
     }
 
     /// Reads data from the response body. Must be called after `do`.
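
Why the trailing-header handling above matters, shown as a hedged sketch against a plain in-memory reader (the real reader is the Request itself): callers treat a zero-length read as end of body, so read() may only return 0 once the chunked trailer has been consumed.

    const std = @import("std");

    test "a zero-length read terminates the caller's drain loop" {
        var fbs = std.io.fixedBufferStream("hello world");
        const reader = fbs.reader();

        var total: usize = 0;
        var buf: [4]u8 = undefined;
        while (true) {
            const n = try reader.read(&buf);
            if (n == 0) break; // so read() must only return 0 after trailers are consumed
            total += n;
        }
        try std.testing.expectEqual(@as(usize, 11), total);
    }
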
@@ -800,15 +807,19 @@ pub const Request = struct {
         }
     }
 
     pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
         var index: usize = 0;
         while (index < bytes.len) {
             index += try write(req, bytes[index..]);
         }
     }
 
     pub const FinishError = WriteError || error{MessageNotCompleted};
 
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
         switch (req.transfer_encoding) {
-            .chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| {
-                req.client.last_error = .{ .write = err };
-                return error.WriteFailed;
-            },
+            .chunked => try req.connection.data.conn.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
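
A rough stand-alone sketch of the finish() semantics above (Transfer and finishTrailer are hypothetical stand-ins, not std.http API): chunked bodies are terminated with a zero-length chunk, fixed-length bodies must already be fully written, and bodiless messages need no trailer.

    const std = @import("std");

    const Transfer = union(enum) {
        content_length: u64,
        chunked: void,
        none: void,
    };

    fn finishTrailer(transfer: Transfer, out: anytype) !void {
        switch (transfer) {
            // "0\r\n\r\n" is the terminating zero-length chunk.
            .chunked => try out.writeAll("0\r\n\r\n"),
            // A declared Content-Length that was not fully written is an error.
            .content_length => |remaining| if (remaining != 0) return error.MessageNotCompleted,
            .none => {},
        }
    }

    test "finishTrailer" {
        var buf: [16]u8 = undefined;
        var fbs = std.io.fixedBufferStream(&buf);
        try finishTrailer(.chunked, fbs.writer());
        try std.testing.expectEqualStrings("0\r\n\r\n", fbs.getWritten());
        try std.testing.expectError(error.MessageNotCompleted, finishTrailer(Transfer{ .content_length = 3 }, fbs.writer()));
    }
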
@@ -923,7 +934,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     }
 }
 
-pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
+pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
     UnsupportedUrlScheme,
     UriMissingHost,
 
@@ -998,6 +1009,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
         .handle_redirects = options.handle_redirects,
         .response = .{
             .status = undefined,
             .reason = undefined,
             .version = undefined,
             .headers = undefined,
             .parser = switch (options.header_strategy) {
@@ -1011,8 +1023,6 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
     req.arena = std.heap.ArenaAllocator.init(client.allocator);
 
-    try req.start(uri);
-
     return req;
 }
 
 
@@ -23,21 +23,33 @@ pub const Connection = struct {
 
     pub const Protocol = enum { plain };
 
-    pub fn read(conn: *Connection, buffer: []u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.read(buffer),
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.read(buffer),
             // .tls => return conn.tls_client.read(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.readAtLeast(buffer, len),
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.readAtLeast(buffer, len),
             // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub const ReadError = net.Stream.ReadError;
+    pub const ReadError = error{
+        ConnectionTimedOut,
+        ConnectionResetByPeer,
+        UnexpectedReadFailure,
+    };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
 
@@ -45,21 +57,31 @@ pub const Connection = struct {
         return Reader{ .context = conn };
     }
 
-    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
-        switch (conn.protocol) {
-            .plain => return conn.stream.writeAll(buffer),
+    pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
+        return switch (conn.protocol) {
+            .plain => conn.stream.writeAll(buffer),
             // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub fn write(conn: *Connection, buffer: []const u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.write(buffer),
+    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.write(buffer),
             // .tls => return conn.tls_client.write(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub const WriteError = net.Stream.WriteError || error{};
+    pub const WriteError = error{
+        ConnectionResetByPeer,
+        UnexpectedWriteFailure,
+    };
 
     pub const Writer = std.io.Writer(*Connection, WriteError, write);
 
     pub fn writer(conn: *Connection) Writer {
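
The read, readAtLeast, writeAll, and write changes above all use the same collapsing pattern. A minimal stand-alone sketch (LowLevelWriteError is a hypothetical stand-in for net.Stream.WriteError):

    const std = @import("std");

    // Hypothetical wide error set standing in for net.Stream.WriteError.
    const LowLevelWriteError = error{
        BrokenPipe,
        ConnectionResetByPeer,
        AccessDenied,
        Unexpected,
    };

    // Curated set, mirroring Connection.WriteError above.
    const WriteError = error{
        ConnectionResetByPeer,
        UnexpectedWriteFailure,
    };

    // Forward the cases callers can act on; collapse everything else into
    // a single "unexpected" error so the public set stays small.
    fn writeMapped(result: LowLevelWriteError!usize) WriteError!usize {
        return result catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    test "low-level errors collapse into the curated set" {
        try std.testing.expectEqual(@as(usize, 4), try writeMapped(4));
        try std.testing.expectError(error.ConnectionResetByPeer, writeMapped(error.BrokenPipe));
        try std.testing.expectError(error.UnexpectedWriteFailure, writeMapped(error.AccessDenied));
    }
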
@@ -155,6 +177,25 @@ pub const BufferedConnection = struct {
     }
 };
 
+/// The mode of transport for responses.
+pub const ResponseTransfer = union(enum) {
+    content_length: u64,
+    chunked: void,
+    none: void,
+};
+
+/// The decompressor for request messages.
+pub const Compression = union(enum) {
+    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
+    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
+    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
+
+    deflate: DeflateDecompressor,
+    gzip: GzipDecompressor,
+    zstd: ZstdDecompressor,
+    none: void,
+};
+
 /// A HTTP request originating from a client.
 pub const Request = struct {
     pub const ParseError = Allocator.Error || error{
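
The Compression union added above is what Response.read (further down) dispatches on. A toy sketch of that dispatch pattern (ToyCompression and its variants are hypothetical; FixedBufferStream stands in for the real zlib/gzip/zstd streams):

    const std = @import("std");

    // Each variant owns a stream; read() switches on whichever variant is
    // active and captures it by pointer (|*s|) so its state advances in place.
    const ToyCompression = union(enum) {
        upper: std.io.FixedBufferStream([]const u8),
        none: std.io.FixedBufferStream([]const u8),

        fn read(c: *ToyCompression, buffer: []u8) !usize {
            return switch (c.*) {
                .upper => |*s| blk: {
                    const n = try s.read(buffer);
                    // Pretend "decompression": uppercase the bytes read.
                    for (buffer[0..n]) |*b| b.* = std.ascii.toUpper(b.*);
                    break :blk n;
                },
                .none => |*s| try s.read(buffer),
            };
        }
    };

    test "dispatch on the active variant" {
        var c = ToyCompression{ .upper = std.io.fixedBufferStream(@as([]const u8, "hi")) };
        var buf: [4]u8 = undefined;
        const n = try c.read(&buf);
        try std.testing.expectEqualStrings("HI", buf[0..n]);
    }
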
@@ -165,10 +206,11 @@ pub const Request = struct {
         HttpHeaderContinuationsUnsupported,
         HttpTransferEncodingUnsupported,
         HttpConnectionHeaderUnsupported,
-        InvalidCharacter,
+        InvalidContentLength,
+        CompressionNotSupported,
     };
 
-    pub fn parse(req: *Request, bytes: []const u8) !void {
+    pub fn parse(req: *Request, bytes: []const u8) ParseError!void {
         var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
 
         const first_line = it.next() orelse return error.HttpHeadersInvalid;
@@ -211,7 +253,7 @@ pub const Request = struct {
 
             if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                 if (req.content_length != null) return error.HttpHeadersInvalid;
-                req.content_length = try std.fmt.parseInt(u64, header_value, 10);
+                req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
             } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                 // Transfer-Encoding: second, first
                 // Transfer-Encoding: deflate, chunked
@@ -321,6 +363,8 @@ pub const Response = struct {
         }
     }
 
+    pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
+
     /// Send the response headers.
     pub fn do(res: *Response) !void {
         var buffered = std.io.bufferedWriter(res.connection.writer());
@@ -356,7 +400,7 @@ pub const Response = struct {
             }
         } else {
             if (has_content_length) {
-                const content_length = try std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10);
+                const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
 
                 res.transfer_encoding = .{ .content_length = content_length };
             } else if (has_transfer_encoding) {
@@ -386,23 +430,23 @@ pub const Response = struct {
         return .{ .context = res };
     }
 
-    pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
-        if (res.request.parser.isComplete()) return 0;
+    fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
+        if (res.request.parser.done) return 0;
 
         var index: usize = 0;
         while (index == 0) {
             const amt = try res.request.parser.read(&res.connection, buf[index..], false);
-            if (amt == 0 and res.request.parser.isComplete()) break;
+            if (amt == 0 and res.request.parser.done) break;
             index += amt;
         }
 
         return index;
     }
 
-    pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
+    pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
 
     /// Wait for the client to send a complete request head.
-    pub fn wait(res: *Response) !void {
+    pub fn wait(res: *Response) WaitError!void {
         while (true) {
             try res.connection.fill();
 
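
The transferRead loop above has a retry-until-progress shape: keep polling until at least one byte arrives or the parser reports the message complete, so a return value of 0 always means end of message. A hedged sketch with a fake parser (FakeParser is hypothetical, standing in for proto.HeadersParser):

    const std = @import("std");

    const FakeParser = struct {
        reads: []const usize, // byte counts the parser would produce
        i: usize = 0,
        done: bool = false,

        fn read(p: *FakeParser) usize {
            if (p.i == p.reads.len) {
                p.done = true;
                return 0;
            }
            defer p.i += 1;
            return p.reads[p.i];
        }
    };

    fn transferRead(p: *FakeParser) usize {
        if (p.done) return 0;

        var index: usize = 0;
        while (index == 0) {
            const amt = p.read();
            // Only a zero-length read *after* completion ends the loop.
            if (amt == 0 and p.done) break;
            index += amt;
        }
        return index;
    }

    test "zero-length reads are retried until the parser is done" {
        var p = FakeParser{ .reads = &.{ 0, 3 } };
        try std.testing.expectEqual(@as(usize, 3), transferRead(&p));
        try std.testing.expectEqual(@as(usize, 0), transferRead(&p));
        try std.testing.expect(p.done);
    }
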
@@ -445,10 +489,10 @@ pub const Response = struct {
         if (res.request.transfer_compression) |tc| switch (tc) {
             .compress => return error.CompressionNotSupported,
             .deflate => res.request.compression = .{
-                .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
+                .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
             },
             .gzip => res.request.compression = .{
-                .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
+                .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
             },
             .zstd => res.request.compression = .{
                 .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
@@ -457,7 +501,7 @@ pub const Response = struct {
         }
     }
 
-    pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;
+    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
 
     pub const Reader = std.io.Reader(*Response, ReadError, read);
 
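
How the `std.io.Reader(*Response, ReadError, read)` adapter above is typically used: given a context type, an error set, and a read function, std.io.Reader of this era layers helpers such as readAll on top. A toy context (Counter is hypothetical) as a sketch:

    const std = @import("std");

    const Counter = struct {
        left: usize,

        pub const ReadError = error{};

        fn read(self: *Counter, buffer: []u8) ReadError!usize {
            // Produce up to `left` bytes of 'x', then report end of stream.
            const n = if (self.left < buffer.len) self.left else buffer.len;
            for (buffer[0..n]) |*b| b.* = 'x';
            self.left -= n;
            return n;
        }

        pub const Reader = std.io.Reader(*Counter, ReadError, read);

        fn reader(self: *Counter) Reader {
            return .{ .context = self };
        }
    };

    test "std.io.Reader adapter provides higher-level helpers" {
        var c = Counter{ .left = 5 };
        var buf: [16]u8 = undefined;
        const n = try c.reader().readAll(&buf);
        try std.testing.expectEqual(@as(usize, 5), n);
        try std.testing.expectEqualStrings("xxxxx", buf[0..n]);
    }
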
@@ -466,12 +510,23 @@ pub const Response = struct {
     }
 
     pub fn read(res: *Response, buffer: []u8) ReadError!usize {
-        return switch (res.request.compression) {
-            .deflate => |*deflate| try deflate.read(buffer),
-            .gzip => |*gzip| try gzip.read(buffer),
-            .zstd => |*zstd| try zstd.read(buffer),
+        const out_index = switch (res.request.compression) {
+            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
             else => try res.transferRead(buffer),
         };
+
+        if (out_index == 0) {
+            while (!res.request.parser.state.isContent()) { // read trailing headers
+                try res.connection.fill();
+
+                const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
+                res.connection.clear(@intCast(u16, nchecked));
+            }
+        }
+
+        return out_index;
     }
 
     pub fn readAll(res: *Response, buffer: []u8) !usize {
@@ -513,9 +568,18 @@ pub const Response = struct {
         }
     }
 
+    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try write(req, bytes[index..]);
+        }
+    }
+
+    pub const FinishError = WriteError || error{MessageNotCompleted};
+
     /// Finish the body of a request. This notifies the server that you have no more data to send.
-    pub fn finish(res: *Response) !void {
-        switch (res.headers.transfer_encoding) {
+    pub fn finish(res: *Response) FinishError!void {
+        switch (res.transfer_encoding) {
             .chunked => try res.connection.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
@@ -523,25 +587,6 @@ pub const Response = struct {
     }
 };
 
-/// The mode of transport for responses.
-pub const ResponseTransfer = union(enum) {
-    content_length: u64,
-    chunked: void,
-    none: void,
-};
-
-/// The decompressor for request messages.
-pub const Compression = union(enum) {
-    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
-    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
-    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
-
-    deflate: DeflateDecompressor,
-    gzip: GzipDecompressor,
-    zstd: ZstdDecompressor,
-    none: void,
-};
-
 pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
     return .{
         .allocator = allocator,
@@ -485,6 +485,8 @@ fn fetchAndUnpack(
     var req = try http_client.request(uri, h, .{ .method = .GET });
     defer req.deinit();
 
+    try req.start();
+
     try req.do();
 
     if (mem.endsWith(u8, uri.path, ".tar.gz")) {