std.crypto.tls.Client: update to new reader/writer API

Andrew Kelley 2025-05-02 17:27:47 -07:00
parent e326d7e8ec
commit bb7af21d6f
8 changed files with 316 additions and 614 deletions

View File

@@ -168,7 +168,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
has_top_bit = true;
}
const out_slice = out[out.len - expected_len ..];
- br.read(out_slice) catch return error.InvalidEncoding;
+ br.readSlice(out_slice) catch return error.InvalidEncoding;
if (@intFromBool(has_top_bit) != out[0] >> 7) return error.InvalidEncoding;
}
@@ -177,7 +177,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
pub fn fromDer(der: []const u8) EncodingError!Signature {
if (der.len < 2) return error.InvalidEncoding;
var br: std.io.BufferedReader = undefined;
- br.initFixed(der);
+ br.initFixed(@constCast(der));
const buf = br.take(2) catch return error.InvalidEncoding;
if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) return error.InvalidEncoding;
var sig: Signature = mem.zeroInit(Signature, .{});
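For readers new to the post-rewrite API, a minimal, self-contained sketch of the fixed-buffer pattern the two hunks above adopt; the function name is illustrative, but `initFixed`, `take`, and `readSlice` are the calls this commit actually uses. Note that `initFixed` takes a mutable slice, hence the `@constCast` on read-only DER input.

```zig
const std = @import("std");

// Illustrative sketch only: read a 2-byte DER header, then the payload,
// using a BufferedReader backed entirely by a fixed in-memory slice.
fn readDer(der: []const u8, payload: []u8) error{InvalidEncoding}!void {
    var br: std.io.BufferedReader = undefined;
    br.initFixed(@constCast(der)); // fixed buffer; no underlying stream
    // `take` returns a slice into the buffer, or error.EndOfStream.
    const header = br.take(2) catch return error.InvalidEncoding;
    if (header[0] != 0x30) return error.InvalidEncoding;
    // `readSlice` (formerly `read`) fills `payload` exactly, or errors.
    br.readSlice(payload) catch return error.InvalidEncoding;
}
```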

View File

@@ -655,7 +655,7 @@ pub const Decoder = struct {
}
/// Use this function to increase `their_end`.
- pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+ pub fn readAtLeast(d: *Decoder, stream: *std.io.BufferedReader, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
@@ -663,14 +663,16 @@
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
- const actual_amt = try stream.readAtLeast(dest, request_amt);
- if (actual_amt < request_amt) return error.TlsConnectionTruncated;
- d.cap += actual_amt;
+ stream.readSlice(dest[0..request_amt]) catch |err| switch (err) {
+     error.EndOfStream => return error.TlsConnectionTruncated,
+     error.ReadFailed => return error.ReadFailed,
+ };
+ d.cap += request_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
- pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+ pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.BufferedReader, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
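The error mapping in the `readAtLeast` hunk is the general pattern for porting code away from the old short-read API: `readSlice` fills the destination completely or fails, so its two errors translate one-to-one into the TLS-level errors. A standalone sketch (the helper name is made up):

```zig
const std = @import("std");

// Hypothetical helper mirroring the Decoder change above.
fn readExact(
    stream: *std.io.BufferedReader,
    dest: []u8,
) error{ TlsConnectionTruncated, ReadFailed }!void {
    stream.readSlice(dest) catch |err| switch (err) {
        // The peer closed before a full record arrived.
        error.EndOfStream => return error.TlsConnectionTruncated,
        // The underlying stream reported a failure.
        error.ReadFailed => return error.ReadFailed,
    };
}
```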

File diff suppressed because it is too large.

View File

@@ -487,9 +487,7 @@ pub const Reader = struct {
return decompressor.compression.gzip.reader();
},
.zstd => {
- decompressor.compression = .{ .zstd = .init(reader.in, .{
-     .window_buffer = decompression_buffer,
- }) };
+ decompressor.compression = .{ .zstd = .init(reader.in, .{ .verify_checksum = false }) };
return decompressor.compression.zstd.reader();
},
.compress => unreachable,
@@ -742,7 +740,7 @@ pub const Decompressor = struct {
pub const Compression = union(enum) {
deflate: std.compress.zlib.Decompressor,
gzip: std.compress.gzip.Decompressor,
- zstd: std.compress.zstd.Decompressor,
+ zstd: std.compress.zstd.Decompress,
none: void,
};
@@ -768,12 +766,8 @@ pub const Decompressor = struct {
return decompressor.compression.gzip.reader();
},
.zstd => {
- const first_half = buffer[0 .. buffer.len / 2];
- const second_half = buffer[buffer.len / 2 ..];
- decompressor.buffered_reader = transfer_reader.buffered(first_half);
- decompressor.compression = .{ .zstd = .init(&decompressor.buffered_reader, .{
-     .window_buffer = second_half,
- }) };
+ decompressor.buffered_reader = transfer_reader.buffered(buffer);
+ decompressor.compression = .{ .zstd = .init(&decompressor.buffered_reader, .{}) };
return decompressor.compression.zstd.reader();
},
.compress => unreachable,
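Both zstd hunks reflect the rename from `Decompressor` to `Decompress` and the removal of the caller-supplied `window_buffer`. A hedged sketch of the new initialization, assuming `input` is a `*std.io.BufferedReader` over the transfer stream:

```zig
const std = @import("std");

// Sketch: initialize in caller-provided storage so the reader stays valid.
fn initZstd(
    dest: *std.compress.zstd.Decompress,
    input: *std.io.BufferedReader,
) void {
    // Options replace the old caller-allocated `window_buffer`;
    // `verify_checksum = false` matches the HTTP hunk above.
    dest.* = .init(input, .{ .verify_checksum = false });
}
```

The decompressed stream is then consumed through `dest.reader()`, as the surrounding code does.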

View File

@@ -28,7 +28,7 @@ tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.cr
/// If non-null, ssl secrets are logged to a stream. Creating such a stream
/// allows other processes with access to that stream to decrypt all
/// traffic over connections created with this `Client`.
- ssl_key_log: ?*std.io.BufferedWriter = null,
+ ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null,
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
@@ -230,8 +230,11 @@ pub const Connection = struct {
stream_writer: net.Stream.Writer,
stream_reader: net.Stream.Reader,
/// HTTP protocol from client to server.
- /// This either goes directly to `stream`, or to a TLS client.
+ /// This either goes directly to `stream_writer`, or to a TLS client.
writer: std.io.BufferedWriter,
+ /// HTTP protocol from server to client.
+ /// This either comes directly from `stream_reader`, or from a TLS client.
+ reader: std.io.BufferedReader,
/// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
pool_node: std.DoublyLinkedList.Node,
port: u16,
@@ -241,8 +244,6 @@ pub const Connection = struct {
protocol: Protocol,
const Plain = struct {
- /// Data from `Connection.stream`.
- reader: std.io.BufferedReader,
connection: Connection,
fn create(
@@ -267,6 +268,7 @@ pub const Connection = struct {
.stream_writer = stream.writer(),
.stream_reader = stream.reader(),
.writer = plain.connection.stream_writer.interface().buffered(socket_write_buffer),
+ .reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
@@ -274,7 +276,6 @@ pub const Connection = struct {
.closing = false,
.protocol = .plain,
},
- .reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
};
return plain;
}
@@ -327,6 +328,7 @@ pub const Connection = struct {
.stream_writer = stream.writer(),
.stream_reader = stream.reader(),
.writer = tls.client.writer().buffered(socket_write_buffer),
+ .reader = tls.client.reader().unbuffered(),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
@@ -386,21 +388,6 @@ pub const Connection = struct {
};
}
- /// This is either data from `stream`, or `Tls.client`.
- fn reader(c: *Connection) *std.io.BufferedReader {
-     return switch (c.protocol) {
-         .tls => {
-             if (disable_tls) unreachable;
-             const tls: *Tls = @fieldParentPtr("connection", c);
-             return &tls.client.reader;
-         },
-         .plain => {
-             const plain: *Plain = @fieldParentPtr("connection", c);
-             return &plain.reader;
-         },
-     };
- }
/// If this is called without calling `flush` or `end`, data will be
/// dropped unsent.
pub fn destroy(c: *Connection) void {
@@ -1556,7 +1543,7 @@ pub fn request(
.client = client,
.connection = connection,
.reader = .{
- .in = connection.reader(),
+ .in = &connection.reader,
.state = .ready,
.body_state = undefined,
},
@@ -1670,8 +1657,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
const decompress_buffer: []u8 = switch (response.head.content_encoding) {
.identity => &.{},
- .zstd => options.decompress_buffer orelse
-     try client.allocator.alloc(u8, std.compress.zstd.default_window_len * 2),
+ .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len),
else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024),
};
defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
@@ -1681,7 +1667,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
const list = storage.list;
if (storage.allocator) |allocator| {
- reader.readRemainingArrayList(allocator, null, list, storage.append_limit) catch |err| switch (err) {
+ reader.readRemainingArrayList(allocator, null, list, storage.append_limit, 128) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
else => |e| return e,
};
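The net effect of the Client changes: `Connection` now owns a single `reader: std.io.BufferedReader` field, wired once at connect time (socket-buffered for plain connections, unbuffered on top of the TLS client's own buffering), and the protocol-dispatching `reader()` method is gone. A hedged sketch of what consuming code looks like after the change; the struct shape and function are illustrative:

```zig
const std = @import("std");

// Assumed minimal shape of the post-change struct, for illustration.
const Connection = struct {
    reader: std.io.BufferedReader,
    writer: std.io.BufferedWriter,
};

fn peekStatusLine(connection: *Connection) ![]u8 {
    // No protocol switch needed; take the field's address directly,
    // just as `request` now does with `.in = &connection.reader`.
    const in: *std.io.BufferedReader = &connection.reader;
    return in.take(12); // e.g. "HTTP/1.1 200" (illustrative length)
}
```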

View File

@@ -252,6 +252,12 @@ pub fn toss(br: *BufferedReader, n: usize) void {
assert(br.seek <= br.end);
}
+ pub fn unread(noalias br: *BufferedReader, noalias data: []const u8) void {
+     _ = br;
+     _ = data;
+     @panic("TODO");
+ }
/// Equivalent to `peek` followed by `toss`.
///
/// The data returned is invalidated by the next call to `take`, `peek`,
@@ -736,24 +742,25 @@ pub fn discardDelimiterExclusive(br: *BufferedReader, delimiter: u8) Reader.Shor
/// Asserts buffer capacity is at least `n`.
pub fn fill(br: *BufferedReader, n: usize) Reader.Error!void {
assert(n <= br.buffer.len);
- const buffer = br.buffer[0..br.end];
- const seek = br.seek;
- if (seek + n <= buffer.len) {
+ if (br.seek + n <= br.end) {
@branchHint(.likely);
return;
}
- if (seek > 0) {
-     const remainder = buffer[seek..];
-     std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
-     br.end = remainder.len;
-     br.seek = 0;
- }
- while (true) {
+ rebaseCapacity(br, n);
+ while (br.end < br.seek + n) {
br.end += try br.unbuffered_reader.readVec(&.{br.buffer[br.end..]});
- if (n <= br.end) return;
}
}
+ /// Fills the buffer with at least one more byte of data, without advancing the
+ /// seek position, doing exactly one underlying read.
+ ///
+ /// Asserts buffer capacity is at least 1.
+ pub fn fillMore(br: *BufferedReader) Reader.Error!void {
+     rebaseCapacity(br, 1);
+     br.end += try br.unbuffered_reader.readVec(&.{br.buffer[br.end..]});
+ }
/// Returns the next byte from the stream or returns `error.EndOfStream`.
///
/// Does not advance the seek position.
@@ -783,7 +790,7 @@ pub fn takeByteSigned(br: *BufferedReader) Reader.Error!i8 {
return @bitCast(try br.takeByte());
}
- /// Asserts the buffer was initialized with a capacity at least `@sizeOf(T)`.
+ /// Asserts the buffer was initialized with a capacity at least `@bitSizeOf(T) / 8`.
pub inline fn takeInt(br: *BufferedReader, comptime T: type, endian: std.builtin.Endian) Reader.Error!T {
const n = @divExact(@typeInfo(T).int.bits, 8);
return std.mem.readInt(T, try br.takeArray(n), endian);
@@ -957,6 +964,14 @@ pub fn rebase(br: *BufferedReader) void {
br.end = data.len;
}
+ /// Ensures `capacity` more data can be buffered without rebasing, by rebasing
+ /// if necessary.
+ ///
+ /// Asserts `capacity` is within the buffer capacity.
+ pub fn rebaseCapacity(br: *BufferedReader, capacity: usize) void {
+     if (br.end > br.buffer.len - capacity) rebase(br);
+ }
/// Advances the stream and decreases the size of the storage buffer by `n`,
/// returning the range of bytes no longer accessible by `br`.
///

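The reworked `fill` leans on the new `rebaseCapacity`: rebase once up front so the buffer has room, then loop `readVec` until the requested window is buffered. In usage terms, a successful `fill(n)` guarantees `n` contiguous buffered bytes, so an immediately following peek of the same length cannot hit `EndOfStream`. A hedged sketch (the record layout and `peek` call are illustrative):

```zig
const std = @import("std");

// Sketch: inspect a 5-byte TLS-style record header without consuming it.
fn peekRecordHeader(br: *std.io.BufferedReader) !struct { u8, u16 } {
    try br.fill(5); // rebases if needed, then reads until 5 bytes buffered
    const header = try br.peek(5);
    const content_type = header[0];
    const record_len = std.mem.readInt(u16, header[3..5], .big);
    return .{ content_type, record_len };
}
```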
View File

@@ -100,6 +100,10 @@ pub const Limit = enum(usize) {
return s[0..l.minInt(s.len)];
}
+ pub fn sliceConst(l: Limit, s: []const u8) []const u8 {
+     return s[0..l.minInt(s.len)];
+ }
pub fn toInt(l: Limit) ?usize {
return switch (l) {
else => @intFromEnum(l),

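`sliceConst` is the `[]const u8` counterpart of the existing `slice`, sparing call sites a `@constCast`. A usage sketch, assuming `Limit` is the non-exhaustive `enum(usize)` suggested by the `else` arm in `toInt`, and reachable as `std.io.Writer.Limit` (the path used elsewhere in this commit):

```zig
const std = @import("std");

test "Limit.sliceConst clamps a const slice" {
    const data: []const u8 = "hello world";
    const limit: std.io.Writer.Limit = @enumFromInt(5);
    // Same clamping behavior as `slice`, but no mutable slice required.
    try std.testing.expectEqualStrings("hello", limit.sliceConst(data));
}
```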
View File

@@ -140,7 +140,7 @@ pub fn failingWriteFile(
limit: std.io.Writer.Limit,
headers_and_trailers: []const []const u8,
headers_len: usize,
- ) Error!usize {
+ ) FileError!usize {
_ = context;
_ = file;
_ = offset;
@@ -165,7 +165,7 @@ pub fn unimplementedWriteFile(
limit: std.io.Writer.Limit,
headers_and_trailers: []const []const u8,
headers_len: usize,
- ) Error!usize {
+ ) FileError!usize {
_ = context;
_ = file;
_ = offset;