std.http: further curate error set, remove last_error

This commit is contained in:
Nameless 2023-04-12 23:26:40 -05:00
parent 038ed32cff
commit 2c492064fb
No known key found for this signature in database
GPG Key ID: A477BC03CAFCCAF7

View File

@ -25,9 +25,6 @@ next_https_rescan_certs: bool = true,
/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
/// The last error that occurred on this client. This is not threadsafe, do not expect it to be completely accurate.
last_error: ?ExtraError = null,
pub const ExtraError = union(enum) {
pub const TcpConnectError = std.net.TcpConnectToHostError;
pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
@ -184,31 +181,33 @@ pub const Connection = struct {
pub const Protocol = enum { plain, tls };
pub fn read(conn: *Connection, buffer: []u8) !usize {
switch (conn.protocol) {
.plain => return conn.stream.read(buffer),
.tls => return conn.tls_client.read(conn.stream, buffer),
}
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.read(buffer),
.tls => conn.tls_client.read(conn.stream, buffer),
} catch |err| switch (err) {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.TlsAlert => return error.TlsAlert,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
switch (conn.protocol) {
.plain => return conn.stream.readAtLeast(buffer, len),
.tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
}
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.readAtLeast(buffer, len),
.tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
} catch |err| switch (err) {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.TlsAlert => return error.TlsAlert,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
pub const ReadError = net.Stream.ReadError || error{
TlsConnectionTruncated,
TlsRecordOverflow,
TlsDecodeError,
TlsAlert,
TlsBadRecordMac,
Overflow,
TlsBadLength,
TlsIllegalParameter,
TlsUnexpectedMessage,
};
pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure };
pub const Reader = std.io.Reader(*Connection, ReadError, read);
@ -217,20 +216,30 @@ pub const Connection = struct {
}
pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
switch (conn.protocol) {
.plain => return conn.stream.writeAll(buffer),
.tls => return conn.tls_client.writeAll(conn.stream, buffer),
}
return switch (conn.protocol) {
.plain => conn.stream.writeAll(buffer),
.tls => conn.tls_client.writeAll(conn.stream, buffer),
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub fn write(conn: *Connection, buffer: []const u8) !usize {
switch (conn.protocol) {
.plain => return conn.stream.write(buffer),
.tls => return conn.tls_client.write(conn.stream, buffer),
}
return switch (conn.protocol) {
.plain => conn.stream.write(buffer),
.tls => conn.tls_client.write(conn.stream, buffer),
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub const WriteError = net.Stream.WriteError || error{};
pub const WriteError = error{
ConnectionResetByPeer,
UnexpectedWriteFailure,
};
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
@ -604,7 +613,7 @@ pub const Request = struct {
try buffered.flush();
}
pub const TransferReadError = proto.HeadersParser.ReadError || error{ReadFailed};
pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
@ -617,10 +626,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
const amt = req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip) catch |err| {
req.client.last_error = .{ .read = err };
return error.ReadFailed;
};
const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
if (amt == 0 and req.response.parser.done) break;
index += amt;
}
@ -638,10 +644,7 @@ pub const Request = struct {
pub fn do(req: *Request) DoError!void {
while (true) { // handle redirects
while (true) { // read headers
req.connection.data.buffered.fill() catch |err| {
req.client.last_error = .{ .read = err };
return error.ReadFailed;
};
try req.connection.data.buffered.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
req.connection.data.buffered.clear(@intCast(u16, nchecked));
@ -712,16 +715,10 @@ pub const Request = struct {
if (req.response.headers.transfer_compression) |tc| switch (tc) {
.compress => return error.CompressionNotSupported,
.deflate => req.response.compression = .{
.deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch |err| {
req.client.last_error = .{ .zlib_init = err };
return error.CompressionInitializationFailed;
},
.deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
},
.gzip => req.response.compression = .{
.gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch |err| {
req.client.last_error = .{ .gzip_init = err };
return error.CompressionInitializationFailed;
},
.gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
},
.zstd => req.response.compression = .{
.zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
@ -734,7 +731,7 @@ pub const Request = struct {
}
}
pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError;
pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
pub const Reader = std.io.Reader(*Request, ReadError, read);
@ -746,30 +743,15 @@ pub const Request = struct {
pub fn read(req: *Request, buffer: []u8) ReadError!usize {
while (true) {
const out_index = switch (req.response.compression) {
.deflate => |*deflate| deflate.read(buffer) catch |err| {
req.client.last_error = .{ .decompress = err };
err catch {};
return error.ReadFailed;
},
.gzip => |*gzip| gzip.read(buffer) catch |err| {
req.client.last_error = .{ .decompress = err };
err catch {};
return error.ReadFailed;
},
.zstd => |*zstd| zstd.read(buffer) catch |err| {
req.client.last_error = .{ .decompress = err };
err catch {};
return error.ReadFailed;
},
.deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
else => try req.transferRead(buffer),
};
if (out_index == 0) {
while (!req.response.parser.state.isContent()) { // read trailing headers
req.connection.data.buffered.fill() catch |err| {
req.client.last_error = .{ .read = err };
return error.ReadFailed;
};
try req.connection.data.buffered.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
req.connection.data.buffered.clear(@intCast(u16, nchecked));
@ -784,17 +766,14 @@ pub const Request = struct {
pub fn readAll(req: *Request, buffer: []u8) !usize {
var index: usize = 0;
while (index < buffer.len) {
const amt = read(req, buffer[index..]) catch |err| {
req.client.last_error = .{ .read = err };
return error.ReadFailed;
};
const amt = try read(req, buffer[index..]);
if (amt == 0) break;
index += amt;
}
return index;
}
pub const WriteError = error{ WriteFailed, NotWriteable, MessageTooLong };
pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
pub const Writer = std.io.Writer(*Request, WriteError, write);
@ -806,28 +785,16 @@ pub const Request = struct {
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
switch (req.headers.transfer_encoding) {
.chunked => {
req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len}) catch |err| {
req.client.last_error = .{ .write = err };
return error.WriteFailed;
};
req.connection.data.conn.writeAll(bytes) catch |err| {
req.client.last_error = .{ .write = err };
return error.WriteFailed;
};
req.connection.data.conn.writeAll("\r\n") catch |err| {
req.client.last_error = .{ .write = err };
return error.WriteFailed;
};
try req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len});
try req.connection.data.conn.writeAll(bytes);
try req.connection.data.conn.writeAll("\r\n");
return bytes.len;
},
.content_length => |*len| {
if (len.* < bytes.len) return error.MessageTooLong;
const amt = req.connection.data.conn.write(bytes) catch |err| {
req.client.last_error = .{ .write = err };
return error.WriteFailed;
};
const amt = try req.connection.data.conn.write(bytes);
len.* -= amt;
return amt;
},
@ -835,8 +802,10 @@ pub const Request = struct {
}
}
pub const FinishError = WriteError || error{ MessageNotCompleted };
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) !void {
pub fn finish(req: *Request) FinishError!void {
switch (req.headers.transfer_encoding) {
.chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| {
req.client.last_error = .{ .write = err };
@ -857,7 +826,7 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
pub const ConnectError = Allocator.Error || error{ ConnectionFailed, TlsInitializationFailed };
pub const ConnectError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
/// This function is threadsafe.
@ -873,9 +842,16 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| {
client.last_error = .{ .connect = err };
return error.ConnectionFailed;
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
error.ConnectionRefused => return error.ConnectionRefused,
error.NetworkUnreachable => return error.NetworkUnreachable,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure,
error.NameServerFailure => return error.NameServerFailure,
error.UnknownHostName => return error.UnknownHostName,
error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses,
else => return error.UnexpectedConnectFailure,
};
errdefer stream.close();
@ -896,10 +872,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch |err| {
client.last_error = .{ .tls = err };
return error.TlsInitializationFailed;
};
conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
@ -911,12 +884,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
return conn;
}
pub const RequestError = ConnectError || error{
pub const RequestError = ConnectError || BufferedConnection.WriteError || error{
UnsupportedUrlScheme,
UriMissingHost,
CertificateAuthorityBundleFailed,
WriteFailed,
CertificateBundleLoadFailure,
};
pub const Options = struct {
@ -962,10 +934,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
defer client.ca_bundle_mutex.unlock();
if (client.next_https_rescan_certs) {
client.ca_bundle.rescan(client.allocator) catch |err| {
client.last_error = .{ .ca_bundle = err };
return error.CertificateAuthorityBundleFailed;
};
client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure;
@atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
}
}
@ -989,13 +958,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
req.arena = std.heap.ArenaAllocator.init(client.allocator);
req.start(uri, headers) catch |err| {
if (err == error.OutOfMemory) return error.OutOfMemory;
const err_casted = @errSetCast(BufferedConnection.WriteError, err);
client.last_error = .{ .write = err_casted };
return error.WriteFailed;
};
try req.start(uri, headers);
return req;
}