std: WIP update more to new reader/writer

delete some bad readers/writers

add limited reader

update TLS

about to do something drastic to compress
This commit is contained in:
Andrew Kelley 2025-04-30 14:18:45 -07:00
parent c7040171fb
commit 25ac70f973
18 changed files with 388 additions and 437 deletions

View File

@ -1,75 +1,19 @@
//! Compression algorithms.

const std = @import("std.zig");

pub const flate = @import("compress/flate.zig");
pub const gzip = @import("compress/gzip.zig");
pub const zlib = @import("compress/zlib.zig");
pub const lzma = @import("compress/lzma.zig");
pub const lzma2 = @import("compress/lzma2.zig");
pub const xz = @import("compress/xz.zig");
pub const zstd = @import("compress/zstandard.zig");
/// Returns a reader type that forwards reads to `ReaderType` while feeding
/// every byte produced into a `HasherType` instance, so a running hash of the
/// consumed stream is maintained transparently.
pub fn HashedReader(ReaderType: type, HasherType: type) type {
    return struct {
        child_reader: ReaderType,
        hasher: HasherType,

        const Self = @This();

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*Self, Error, read);

        /// Reads from the underlying reader, then updates the hasher with
        /// exactly the bytes that were produced (partial reads included).
        pub fn read(self: *Self, buf: []u8) Error!usize {
            const n = try self.child_reader.read(buf);
            self.hasher.update(buf[0..n]);
            return n;
        }

        pub fn reader(self: *Self) Reader {
            return .{ .context = self };
        }
    };
}
/// Convenience constructor for `HashedReader`: wraps `reader` so that all
/// bytes read through the result are also fed into `hasher`.
pub fn hashedReader(
    reader: anytype,
    hasher: anytype,
) HashedReader(@TypeOf(reader), @TypeOf(hasher)) {
    return .{
        .child_reader = reader,
        .hasher = hasher,
    };
}
/// Returns a writer type that forwards writes to `WriterType` while feeding
/// every byte accepted by the underlying writer into a `HasherType` instance,
/// so a running hash of the emitted stream is maintained transparently.
pub fn HashedWriter(WriterType: type, HasherType: type) type {
    return struct {
        child_writer: WriterType,
        hasher: HasherType,

        const Self = @This();

        pub const Error = WriterType.Error;
        pub const Writer = std.io.Writer(*Self, Error, write);

        /// Writes to the underlying writer, then updates the hasher with
        /// exactly the bytes that were accepted (short writes included).
        pub fn write(self: *Self, buf: []const u8) Error!usize {
            const n = try self.child_writer.write(buf);
            self.hasher.update(buf[0..n]);
            return n;
        }

        pub fn writer(self: *Self) Writer {
            return .{ .context = self };
        }
    };
}
/// Convenience constructor for `HashedWriter`: wraps `writer` so that all
/// bytes written through the result are also fed into `hasher`.
pub fn hashedWriter(
    writer: anytype,
    hasher: anytype,
) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) {
    return .{
        .child_writer = writer,
        .hasher = hasher,
    };
}
// Reference each submodule exactly once so their tests are included in the
// test build. (Previously `flate`, `gzip`, and `zstd` were listed twice.)
test {
    _ = flate;
    _ = gzip;
    _ = lzma;
    _ = lzma2;
    _ = xz;
    _ = zlib;
    _ = zstd;
}

View File

@ -821,7 +821,7 @@ pub fn BitReader(comptime T: type) type {
/// Skip zero terminated string.
pub fn skipStringZ(self: *Self) !void {
while (true) {
if (try self.readF(u8, 0) == 0) break;
if (try self.readF(u8, .{}) == 0) break;
}
}

View File

@ -1,7 +1,12 @@
const std = @import("std");
const std = @import("../std.zig");
const RingBuffer = std.RingBuffer;
const types = @import("zstandard/types.zig");
/// Recommended amount by the standard. Lower than this may result in inability
/// to decompress common streams.
pub const default_window_len = 8 * 1024 * 1024;
pub const frame = types.frame;
pub const compressed_block = types.compressed_block;
@ -10,7 +15,8 @@ pub const decompress = @import("zstandard/decompress.zig");
pub const Decompressor = struct {
const table_size_max = types.compressed_block.table_size_max;
source: *std.io.BufferedReader,
input: *std.io.BufferedReader,
bytes_read: usize,
state: enum { NewFrame, InFrame, LastBlock },
decode_state: decompress.block.DecodeState,
frame_context: decompress.FrameContext,
@ -23,14 +29,12 @@ pub const Decompressor = struct {
verify_checksum: bool,
checksum: ?u32,
current_frame_decompressed_size: usize,
err: ?Error = null,
pub const Options = struct {
verify_checksum: bool = true,
/// See `default_window_len`.
window_buffer: []u8,
/// Recommended amount by the standard. Lower than this may result
/// in inability to decompress common streams.
pub const default_window_buffer_len = 8 * 1024 * 1024;
};
const WindowBuffer = struct {
@ -45,11 +49,13 @@ pub const Decompressor = struct {
MalformedBlock,
MalformedFrame,
OutOfMemory,
EndOfStream,
};
pub fn init(source: *std.io.BufferedReader, options: Options) Decompressor {
pub fn init(input: *std.io.BufferedReader, options: Options) Decompressor {
return .{
.source = source,
.input = input,
.bytes_read = 0,
.state = .NewFrame,
.decode_state = undefined,
.frame_context = undefined,
@ -65,100 +71,128 @@ pub const Decompressor = struct {
};
}
fn frameInit(self: *Decompressor) !void {
const source_reader = self.source;
switch (try decompress.decodeFrameHeader(source_reader)) {
fn frameInit(d: *Decompressor) !void {
const in = d.input;
switch (try decompress.decodeFrameHeader(in, &d.bytes_read)) {
.skippable => |header| {
try source_reader.skipBytes(header.frame_size, .{});
self.state = .NewFrame;
try in.discardAll(header.frame_size);
d.bytes_read += header.frame_size;
d.state = .NewFrame;
},
.zstandard => |header| {
const frame_context = try decompress.FrameContext.init(
header,
self.buffer.data.len,
self.verify_checksum,
d.buffer.data.len,
d.verify_checksum,
);
const decode_state = decompress.block.DecodeState.init(
&self.literal_fse_buffer,
&self.match_fse_buffer,
&self.offset_fse_buffer,
&d.literal_fse_buffer,
&d.match_fse_buffer,
&d.offset_fse_buffer,
);
self.decode_state = decode_state;
self.frame_context = frame_context;
d.decode_state = decode_state;
d.frame_context = frame_context;
self.checksum = null;
self.current_frame_decompressed_size = 0;
d.checksum = null;
d.current_frame_decompressed_size = 0;
self.state = .InFrame;
d.state = .InFrame;
},
}
}
pub fn reader(self: *Decompressor) std.io.Reader {
return .{ .context = self };
return .{
.context = self,
.vtable = &.{
.read = read,
.readVec = readVec,
.discard = discard,
},
};
}
pub fn read(self: *Decompressor, buffer: []u8) Error!usize {
if (buffer.len == 0) return 0;
fn read(context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit) std.io.Reader.RwError!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const n = try readVec(context, &.{buf});
bw.advance(n);
return n;
}
var size: usize = 0;
while (size == 0) {
while (self.state == .NewFrame) {
const initial_count = self.source.bytes_read;
self.frameInit() catch |err| switch (err) {
error.DictionaryIdFlagUnsupported => return error.DictionaryIdFlagUnsupported,
error.EndOfStream => return if (self.source.bytes_read == initial_count)
0
else
error.MalformedFrame,
else => return error.MalformedFrame,
};
}
size = try self.readInner(buffer);
fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
var trash: [128]u8 = undefined;
const buf = limit.slice(&trash);
return readVec(context, &.{buf});
}
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
const d: *Decompressor = @ptrCast(@alignCast(context));
if (data.len == 0) return 0;
const buffer = data[0];
while (d.state == .NewFrame) {
const initial_count = d.bytes_read;
d.frameInit() catch |err| switch (err) {
error.DictionaryIdFlagUnsupported => {
d.err = error.DictionaryIdFlagUnsupported;
return error.ReadFailed;
},
error.EndOfStream => {
if (d.bytes_read == initial_count) return error.EndOfStream;
d.err = error.MalformedFrame;
return error.ReadFailed;
},
else => {
d.err = error.MalformedFrame;
return error.ReadFailed;
},
};
}
return size;
return d.readInner(buffer) catch |err| {
d.err = err;
return error.ReadFailed;
};
}
fn readInner(self: *Decompressor, buffer: []u8) Error!usize {
std.debug.assert(self.state != .NewFrame);
fn readInner(d: *Decompressor, buffer: []u8) Error!usize {
std.debug.assert(d.state != .NewFrame);
var ring_buffer = RingBuffer{
.data = self.buffer.data,
.read_index = self.buffer.read_index,
.write_index = self.buffer.write_index,
.data = d.buffer.data,
.read_index = d.buffer.read_index,
.write_index = d.buffer.write_index,
};
defer {
self.buffer.read_index = ring_buffer.read_index;
self.buffer.write_index = ring_buffer.write_index;
d.buffer.read_index = ring_buffer.read_index;
d.buffer.write_index = ring_buffer.write_index;
}
const source_reader = self.source;
while (ring_buffer.isEmpty() and self.state != .LastBlock) {
const header_bytes = source_reader.readBytesNoEof(3) catch
return error.MalformedFrame;
const block_header = decompress.block.decodeBlockHeader(&header_bytes);
const in = d.input;
while (ring_buffer.isEmpty() and d.state != .LastBlock) {
const header_bytes = try in.takeArray(3);
d.bytes_read += header_bytes.len;
const block_header = decompress.block.decodeBlockHeader(header_bytes);
decompress.block.decodeBlockReader(
&ring_buffer,
source_reader,
in,
&d.bytes_read,
block_header,
&self.decode_state,
self.frame_context.block_size_max,
&self.literals_buffer,
&self.sequence_buffer,
) catch
return error.MalformedBlock;
&d.decode_state,
d.frame_context.block_size_max,
&d.literals_buffer,
&d.sequence_buffer,
) catch return error.MalformedBlock;
if (self.frame_context.content_size) |size| {
if (self.current_frame_decompressed_size > size) return error.MalformedFrame;
if (d.frame_context.content_size) |size| {
if (d.current_frame_decompressed_size > size) return error.MalformedFrame;
}
const size = ring_buffer.len();
self.current_frame_decompressed_size += size;
d.current_frame_decompressed_size += size;
if (self.frame_context.hasher_opt) |*hasher| {
if (d.frame_context.hasher_opt) |*hasher| {
if (size > 0) {
const written_slice = ring_buffer.sliceLast(size);
hasher.update(written_slice.first);
@ -166,19 +200,19 @@ pub const Decompressor = struct {
}
}
if (block_header.last_block) {
self.state = .LastBlock;
if (self.frame_context.has_checksum) {
const checksum = source_reader.readInt(u32, .little) catch
return error.MalformedFrame;
if (self.verify_checksum) {
if (self.frame_context.hasher_opt) |*hasher| {
d.state = .LastBlock;
if (d.frame_context.has_checksum) {
const checksum = in.readInt(u32, .little) catch return error.MalformedFrame;
d.bytes_read += 4;
if (d.verify_checksum) {
if (d.frame_context.hasher_opt) |*hasher| {
if (checksum != decompress.computeChecksum(hasher))
return error.ChecksumFailure;
}
}
}
if (self.frame_context.content_size) |content_size| {
if (content_size != self.current_frame_decompressed_size) {
if (d.frame_context.content_size) |content_size| {
if (content_size != d.current_frame_decompressed_size) {
return error.MalformedFrame;
}
}
@ -189,8 +223,8 @@ pub const Decompressor = struct {
if (size > 0) {
ring_buffer.readFirstAssumeLength(buffer, size);
}
if (self.state == .LastBlock and ring_buffer.len() == 0) {
self.state = .NewFrame;
if (d.state == .LastBlock and ring_buffer.len() == 0) {
d.state = .NewFrame;
}
return size;
}

View File

@ -807,7 +807,8 @@ pub fn decodeBlockRingBuffer(
/// contain enough bytes.
pub fn decodeBlockReader(
dest: *RingBuffer,
source: anytype,
in: *std.io.BufferedReader,
bytes_read: *usize,
block_header: frame.Zstandard.Block.Header,
decode_state: *DecodeState,
block_size_max: usize,
@ -815,26 +816,29 @@ pub fn decodeBlockReader(
sequence_buffer: []u8,
) !void {
const block_size = block_header.block_size;
var block_reader_limited = std.io.limitedReader(source, block_size);
const block_reader = block_reader_limited.reader();
if (block_size_max < block_size) return error.BlockSizeOverMaximum;
switch (block_header.block_type) {
.raw => {
if (block_size == 0) return;
const slice = dest.sliceAt(dest.write_index, block_size);
try source.readNoEof(slice.first);
try source.readNoEof(slice.second);
var vecs: [2][]u8 = &.{slice.first, slice.second };
try in.readVecAll(&vecs);
assert(slice.first.len + slice.second.len == block_size);
bytes_read.* += block_size;
dest.write_index = dest.mask2(dest.write_index + block_size);
decode_state.written_count += block_size;
},
.rle => {
const byte = try source.readByte();
const byte = try in.takeByte();
bytes_read.* += 1;
for (0..block_size) |_| {
dest.writeAssumeCapacity(byte);
}
decode_state.written_count += block_size;
},
.compressed => {
var block_reader_limited = std.io.limitedReader(source, block_size);
const block_reader = block_reader_limited.reader();
const literals = try decodeLiteralsSection(block_reader, literals_buffer);
const sequences_header = try decodeSequencesHeader(block_reader);

View File

@ -50,7 +50,7 @@ pub const FrameHeader = union(enum) {
skippable: SkippableHeader,
};
pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
pub const HeaderError = error{ ReadFailed, BadMagic, EndOfStream, ReservedBitSet };
/// Returns the header of the frame at the beginning of `source`.
///
@ -61,16 +61,21 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
/// - `error.EndOfStream` if `source` contains fewer than 4 bytes
/// - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the
/// reserved bits are set
pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader {
const magic = try source.readInt(u32, .little);
pub fn decodeFrameHeader(br: *std.io.BufferedReader, bytes_read: *usize) HeaderError!FrameHeader {
const magic = try br.readInt(u32, .little);
bytes_read.* += 4;
const frame_type = try frameType(magic);
switch (frame_type) {
.zstandard => return FrameHeader{ .zstandard = try decodeZstandardHeader(source) },
.skippable => return FrameHeader{
.skippable = .{
.magic_number = magic,
.frame_size = try source.readInt(u32, .little),
},
.zstandard => return .{ .zstandard = try decodeZstandardHeader(br, bytes_read) },
.skippable => {
const result: FrameHeader = .{
.skippable = .{
.magic_number = magic,
.frame_size = try br.readInt(u32, .little),
},
};
bytes_read.* += 4;
return result;
},
}
}

View File

@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{
};
pub const close_notify_alert = [_]u8{
@intFromEnum(AlertLevel.warning),
@intFromEnum(AlertDescription.close_notify),
@intFromEnum(Alert.Level.warning),
@intFromEnum(Alert.Description.close_notify),
};
pub const ProtocolVersion = enum(u16) {
@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) {
_,
};
pub const AlertLevel = enum(u8) {
warning = 1,
fatal = 2,
_,
};
pub const Alert = struct {
level: Level,
description: Description,
pub const AlertDescription = enum(u8) {
pub const Error = error{
TlsAlertUnexpectedMessage,
TlsAlertBadRecordMac,
TlsAlertRecordOverflow,
TlsAlertHandshakeFailure,
TlsAlertBadCertificate,
TlsAlertUnsupportedCertificate,
TlsAlertCertificateRevoked,
TlsAlertCertificateExpired,
TlsAlertCertificateUnknown,
TlsAlertIllegalParameter,
TlsAlertUnknownCa,
TlsAlertAccessDenied,
TlsAlertDecodeError,
TlsAlertDecryptError,
TlsAlertProtocolVersion,
TlsAlertInsufficientSecurity,
TlsAlertInternalError,
TlsAlertInappropriateFallback,
TlsAlertMissingExtension,
TlsAlertUnsupportedExtension,
TlsAlertUnrecognizedName,
TlsAlertBadCertificateStatusResponse,
TlsAlertUnknownPskIdentity,
TlsAlertCertificateRequired,
TlsAlertNoApplicationProtocol,
TlsAlertUnknown,
pub const Level = enum(u8) {
warning = 1,
fatal = 2,
_,
};
close_notify = 0,
unexpected_message = 10,
bad_record_mac = 20,
record_overflow = 22,
handshake_failure = 40,
bad_certificate = 42,
unsupported_certificate = 43,
certificate_revoked = 44,
certificate_expired = 45,
certificate_unknown = 46,
illegal_parameter = 47,
unknown_ca = 48,
access_denied = 49,
decode_error = 50,
decrypt_error = 51,
protocol_version = 70,
insufficient_security = 71,
internal_error = 80,
inappropriate_fallback = 86,
user_canceled = 90,
missing_extension = 109,
unsupported_extension = 110,
unrecognized_name = 112,
bad_certificate_status_response = 113,
unknown_psk_identity = 115,
certificate_required = 116,
no_application_protocol = 120,
_,
pub const Description = enum(u8) {
pub const Error = error{
TlsAlertUnexpectedMessage,
TlsAlertBadRecordMac,
TlsAlertRecordOverflow,
TlsAlertHandshakeFailure,
TlsAlertBadCertificate,
TlsAlertUnsupportedCertificate,
TlsAlertCertificateRevoked,
TlsAlertCertificateExpired,
TlsAlertCertificateUnknown,
TlsAlertIllegalParameter,
TlsAlertUnknownCa,
TlsAlertAccessDenied,
TlsAlertDecodeError,
TlsAlertDecryptError,
TlsAlertProtocolVersion,
TlsAlertInsufficientSecurity,
TlsAlertInternalError,
TlsAlertInappropriateFallback,
TlsAlertMissingExtension,
TlsAlertUnsupportedExtension,
TlsAlertUnrecognizedName,
TlsAlertBadCertificateStatusResponse,
TlsAlertUnknownPskIdentity,
TlsAlertCertificateRequired,
TlsAlertNoApplicationProtocol,
TlsAlertUnknown,
};
pub fn toError(alert: AlertDescription) Error!void {
switch (alert) {
.close_notify => {}, // not an error
.unexpected_message => return error.TlsAlertUnexpectedMessage,
.bad_record_mac => return error.TlsAlertBadRecordMac,
.record_overflow => return error.TlsAlertRecordOverflow,
.handshake_failure => return error.TlsAlertHandshakeFailure,
.bad_certificate => return error.TlsAlertBadCertificate,
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
.certificate_revoked => return error.TlsAlertCertificateRevoked,
.certificate_expired => return error.TlsAlertCertificateExpired,
.certificate_unknown => return error.TlsAlertCertificateUnknown,
.illegal_parameter => return error.TlsAlertIllegalParameter,
.unknown_ca => return error.TlsAlertUnknownCa,
.access_denied => return error.TlsAlertAccessDenied,
.decode_error => return error.TlsAlertDecodeError,
.decrypt_error => return error.TlsAlertDecryptError,
.protocol_version => return error.TlsAlertProtocolVersion,
.insufficient_security => return error.TlsAlertInsufficientSecurity,
.internal_error => return error.TlsAlertInternalError,
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
.user_canceled => {}, // not an error
.missing_extension => return error.TlsAlertMissingExtension,
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
.unrecognized_name => return error.TlsAlertUnrecognizedName,
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
.certificate_required => return error.TlsAlertCertificateRequired,
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
_ => return error.TlsAlertUnknown,
close_notify = 0,
unexpected_message = 10,
bad_record_mac = 20,
record_overflow = 22,
handshake_failure = 40,
bad_certificate = 42,
unsupported_certificate = 43,
certificate_revoked = 44,
certificate_expired = 45,
certificate_unknown = 46,
illegal_parameter = 47,
unknown_ca = 48,
access_denied = 49,
decode_error = 50,
decrypt_error = 51,
protocol_version = 70,
insufficient_security = 71,
internal_error = 80,
inappropriate_fallback = 86,
user_canceled = 90,
missing_extension = 109,
unsupported_extension = 110,
unrecognized_name = 112,
bad_certificate_status_response = 113,
unknown_psk_identity = 115,
certificate_required = 116,
no_application_protocol = 120,
_,
pub fn toError(description: Description) Error!void {
switch (description) {
.close_notify => {}, // not an error
.unexpected_message => return error.TlsAlertUnexpectedMessage,
.bad_record_mac => return error.TlsAlertBadRecordMac,
.record_overflow => return error.TlsAlertRecordOverflow,
.handshake_failure => return error.TlsAlertHandshakeFailure,
.bad_certificate => return error.TlsAlertBadCertificate,
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
.certificate_revoked => return error.TlsAlertCertificateRevoked,
.certificate_expired => return error.TlsAlertCertificateExpired,
.certificate_unknown => return error.TlsAlertCertificateUnknown,
.illegal_parameter => return error.TlsAlertIllegalParameter,
.unknown_ca => return error.TlsAlertUnknownCa,
.access_denied => return error.TlsAlertAccessDenied,
.decode_error => return error.TlsAlertDecodeError,
.decrypt_error => return error.TlsAlertDecryptError,
.protocol_version => return error.TlsAlertProtocolVersion,
.insufficient_security => return error.TlsAlertInsufficientSecurity,
.internal_error => return error.TlsAlertInternalError,
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
.user_canceled => {}, // not an error
.missing_extension => return error.TlsAlertMissingExtension,
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
.unrecognized_name => return error.TlsAlertUnrecognizedName,
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
.certificate_required => return error.TlsAlertCertificateRequired,
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
_ => return error.TlsAlertUnknown,
}
}
}
};
};
pub const SignatureScheme = enum(u16) {

View File

@ -39,8 +39,9 @@ output: *std.io.BufferedWriter,
///
/// Its buffer aliases the buffer of `input`.
reader: std.io.BufferedReader,
/// Populated under various error conditions.
diagnostics: Diagnostics,
/// Populated when `error.TlsAlert` is returned.
alert: ?tls.Alert,
read_err: ?ReadError,
tls_version: tls.ProtocolVersion,
read_seq: u64,
@ -69,15 +70,16 @@ application_cipher: tls.ApplicationCipher,
/// this connection.
ssl_key_log: ?*SslKeyLog,
pub const Diagnostics = union(enum) {
/// Any `ReadFailure` and `WriteFailure` was due to `input` or `output`
/// returning the error, respectively.
transitive,
/// Populated on `error.TlsAlert`.
///
/// If this isn't a error alert, then it's a closure alert, which makes
/// no sense in a handshake.
alert: tls.AlertDescription,
pub const ReadError = error{
/// The alert description will be stored in `alert`.
TlsAlert,
TlsBadLength,
TlsBadRecordMac,
TlsConnectionTruncated,
TlsDecodeError,
TlsRecordOverflow,
TlsUnexpectedMessage,
TlsIllegalParameter,
};
pub const SslKeyLog = struct {
@ -128,14 +130,13 @@ pub const Options = struct {
};
const InitError = error{
//OutOfMemory,
WriteFailure,
ReadFailure,
WriteFailed,
ReadFailed,
InsufficientEntropy,
DiskQuota,
LockViolation,
NotOpenForWriting,
/// The alert description will be stored in `Options.Diagnostics.alert`.
/// The alert description will be stored in `alert`.
TlsAlert,
TlsUnexpectedMessage,
TlsIllegalParameter,
@ -192,7 +193,7 @@ pub fn init(
) InitError!void {
assert(input.storage.buffer.len >= min_buffer_len);
assert(output.buffer.len >= min_buffer_len);
client.diagnostics = .transient;
client.alert = null;
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@ -417,10 +418,10 @@ pub fn init(
switch (ct) {
.alert => {
ctd.ensure(2) catch continue :fragment;
const level = ctd.decode(tls.AlertLevel);
const desc = ctd.decode(tls.AlertDescription);
_ = level;
client.diagnostics = .{ .alert = desc };
client.alert = .{
.level = ctd.decode(tls.Alert.Level),
.description = ctd.decode(tls.Alert.Description),
};
return error.TlsAlert;
},
.change_cipher_spec => {
@ -924,7 +925,7 @@ pub fn writer(c: *Client) std.io.Writer {
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
const c: *Client = @alignCast(@ptrCast(context));
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
const output = &c.output;
const output = c.output;
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
var total_clear: usize = 0;
var ciphertext_end: usize = 0;
@ -942,7 +943,7 @@ fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.i
/// distinguish between a properly finished TLS session, or a truncation
/// attack.
pub fn end(c: *Client) std.io.Writer.Error!void {
const output = &c.output;
const output = c.output;
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
output.advance(prepared.cleartext_len);
@ -1062,16 +1063,16 @@ fn read(
context: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
) std.io.Reader.RwError!std.io.Reader.Status {
) std.io.Reader.RwError!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const status = try readVec(context, &.{buf});
bw.advance(status.len);
return status;
const n = try readVec(context, &.{buf});
bw.advance(n);
return n;
}
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
const c: *Client = @ptrCast(@alignCast(context));
if (c.eof()) return .{ .end = true };
if (c.eof()) return error.EndOfStream;
var vp: VecPut = .{ .iovecs = data };
@ -1093,11 +1094,11 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
if (c.received_close_notify) {
c.partial_ciphertext_end = 0;
assert(vp.total == amt);
return .{ .len = amt, .end = c.eof() };
return amt;
} else if (amt > 0) {
// We don't need more data, so don't call read.
assert(vp.total == amt);
return .{ .len = amt, .end = c.eof() };
return amt;
}
}
@ -1149,7 +1150,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
if (c.allow_truncation_attacks) {
c.received_close_notify = true;
} else {
return error.TlsConnectionTruncated;
return failRead(c, error.TlsConnectionTruncated);
}
}
@ -1168,7 +1169,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
// Perfect split.
if (frag.ptr == frag1.ptr) {
c.partial_ciphertext_end = c.partial_ciphertext_idx;
return .{ .len = vp.total, .end = c.eof() };
return vp.total;
}
frag = frag1;
in = 0;
@ -1188,7 +1189,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len;
@ -1208,7 +1209,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
in += 2;
_ = legacy_version;
const record_len = mem.readInt(u16, frag[in..][0..2], .big);
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
in += 2;
const the_end = in + record_len;
if (the_end > frag.len) {
@ -1255,7 +1256,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
&cleartext_stack_buffer;
const cleartext = cleartext_buf[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
return error.TlsBadRecordMac;
return failRead(c, error.TlsBadRecordMac);
const msg = mem.trimEnd(u8, cleartext, "\x00");
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
},
@ -1287,7 +1288,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
&cleartext_stack_buffer;
const cleartext = cleartext_buf[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
return error.TlsBadRecordMac;
return failRead(c, error.TlsBadRecordMac);
break :cleartext .{ cleartext, ct };
},
else => unreachable,
@ -1296,23 +1297,24 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
c.read_seq = try std.math.add(u64, c.read_seq, 1);
switch (inner_ct) {
.alert => {
if (cleartext.len != 2) return error.TlsDecodeError;
const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
_ = level;
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
switch (desc) {
if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
const alert: tls.Alert = .{
.level = @enumFromInt(cleartext[0]),
.description = @enumFromInt(cleartext[1]),
};
switch (alert.description) {
.close_notify => {
c.received_close_notify = true;
c.partial_ciphertext_end = c.partial_ciphertext_idx;
return .{ .len = vp.total, .end = c.eof() };
return vp.total;
},
.user_canceled => {
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
return failRead(c, error.TlsUnexpectedMessage);
},
else => {
c.diagnostics = .{ .alert = desc };
return error.TlsAlert;
c.alert = alert;
return failRead(c, error.TlsAlert);
},
}
},
@ -1324,8 +1326,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
ct_i += 3;
const next_handshake_i = ct_i + handshake_len;
if (next_handshake_i > cleartext.len)
return error.TlsBadLength;
if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
.new_session_ticket => {
@ -1371,12 +1372,10 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
c.write_seq = 0;
},
.update_not_requested => {},
_ => return error.TlsIllegalParameter,
_ => return failRead(c, error.TlsIllegalParameter),
}
},
else => {
return error.TlsUnexpectedMessage;
},
else => return failRead(c, error.TlsUnexpectedMessage),
}
ct_i = next_handshake_i;
if (ct_i >= cleartext.len) break;
@ -1411,7 +1410,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
vp.next(cleartext.len);
}
},
else => return error.TlsUnexpectedMessage,
else => return failRead(c, error.TlsUnexpectedMessage),
}
in = end;
}
@ -1423,6 +1422,11 @@ fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error
@panic("TODO");
}
fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
c.read_err = err;
return error.ReadFailed;
}
fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
defer if (locked) key_log_file.unlock();

View File

@ -28,7 +28,7 @@ tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.cr
/// If non-null, ssl secrets are logged to a stream. Creating such a stream
/// allows other processes with access to that stream to decrypt all
/// traffic over connections created with this `Client`.
ssl_key_logger: ?*std.io.BufferedWriter = null,
ssl_key_log: ?*std.io.BufferedWriter = null,
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
@ -342,7 +342,7 @@ pub const Connection = struct {
tls.client.init(&tls.reader, &tls.writer, .{
.host = .{ .explicit = remote_host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_logger = client.ssl_key_logger,
.ssl_key_log = client.ssl_key_log,
}) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
@ -1671,7 +1671,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
const decompress_buffer: []u8 = switch (response.head.content_encoding) {
.identity => &.{},
.zstd => options.decompress_buffer orelse
try client.allocator.alloc(u8, std.compress.zstd.Decompressor.Options.default_window_buffer_len * 2),
try client.allocator.alloc(u8, std.compress.zstd.default_window_len * 2),
else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024),
};
defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);

View File

@ -1,16 +1,11 @@
const std = @import("std.zig");
const builtin = @import("builtin");
const root = @import("root");
const c = std.c;
const is_windows = builtin.os.tag == .windows;
const std = @import("std.zig");
const windows = std.os.windows;
const posix = std.posix;
const math = std.math;
const assert = std.debug.assert;
const fs = std.fs;
const mem = std.mem;
const meta = std.meta;
const File = std.fs.File;
const Allocator = std.mem.Allocator;
const Alignment = std.mem.Alignment;
@ -21,12 +16,6 @@ pub const BufferedReader = @import("io/BufferedReader.zig");
pub const BufferedWriter = @import("io/BufferedWriter.zig");
pub const AllocatingWriter = @import("io/AllocatingWriter.zig");
pub const CWriter = @import("io/c_writer.zig").CWriter;
pub const cWriter = @import("io/c_writer.zig").cWriter;
pub const LimitedReader = @import("io/limited_reader.zig").LimitedReader;
pub const limitedReader = @import("io/limited_reader.zig").limitedReader;
pub const MultiWriter = @import("io/multi_writer.zig").MultiWriter;
pub const multiWriter = @import("io/multi_writer.zig").multiWriter;
@ -38,9 +27,6 @@ pub const bitWriter = @import("io/bit_writer.zig").bitWriter;
pub const ChangeDetectionStream = @import("io/change_detection_stream.zig").ChangeDetectionStream;
pub const changeDetectionStream = @import("io/change_detection_stream.zig").changeDetectionStream;
pub const FindByteWriter = @import("io/find_byte_writer.zig").FindByteWriter;
pub const findByteWriter = @import("io/find_byte_writer.zig").findByteWriter;
pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAtomicFile;
pub const tty = @import("io/tty.zig");
@ -63,7 +49,7 @@ pub fn poll(
.windows = if (is_windows) .{
.first_read_done = false,
.overlapped = [1]windows.OVERLAPPED{
mem.zeroes(windows.OVERLAPPED),
std.mem.zeroes(windows.OVERLAPPED),
} ** enum_fields.len,
.small_bufs = undefined,
.active = .{
@ -436,10 +422,10 @@ pub fn PollFiles(comptime StreamEnum: type) type {
for (&struct_fields, enum_fields) |*struct_field, enum_field| {
struct_field.* = .{
.name = enum_field.name,
.type = fs.File,
.type = std.fs.File,
.default_value_ptr = null,
.is_comptime = false,
.alignment = @alignOf(fs.File),
.alignment = @alignOf(std.fs.File),
};
}
return @Type(.{ .@"struct" = .{
@ -459,6 +445,5 @@ test {
_ = @import("io/bit_reader.zig");
_ = @import("io/bit_writer.zig");
_ = @import("io/buffered_atomic_file.zig");
_ = @import("io/c_writer.zig");
_ = @import("io/test.zig");
}

View File

@ -6,6 +6,8 @@ const BufferedReader = std.io.BufferedReader;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayListUnmanaged;
pub const Limited = @import("Reader/Limited.zig");
context: ?*anyopaque,
vtable: *const VTable,
@ -252,6 +254,13 @@ pub fn buffered(r: Reader, buffer: []u8) BufferedReader {
};
}
/// Returns a `Limited` adapter that will read at most `limit` bytes from `r`.
/// The returned value borrows `r`; it does not take ownership of any resource.
pub fn limited(r: Reader, limit: Limit) Limited {
    const result: Limited = .{ .unlimited_reader = r, .remaining = limit };
    return result;
}
fn endingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize {
_ = context;
_ = bw;

View File

@ -0,0 +1,55 @@
//! A `Reader` adapter that permits at most `remaining` bytes to be read
//! from an underlying, otherwise-unlimited reader.
const Limited = @This();
const std = @import("../../std.zig");
const Reader = std.io.Reader;
const BufferedWriter = std.io.BufferedWriter;

/// The wrapped stream; all reads are forwarded here.
unlimited_reader: Reader,
/// Number of bytes still allowed through this adapter.
remaining: Reader.Limit,

/// Returns a `Reader` interface backed by this limited stream.
/// The returned reader borrows `l`, which must outlive it.
pub fn reader(l: *Limited) Reader {
    return .{
        .context = l,
        .vtable = &.{
            .read = passthruRead,
            .readVec = passthruReadVec,
            .discard = passthruDiscard,
        },
    };
}

/// Forwards a read to the underlying reader, capped by `remaining`.
fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize {
    const l: *Limited = @alignCast(@ptrCast(context));
    const combined_limit = limit.min(l.remaining);
    const n = try l.unlimited_reader.read(bw, combined_limit);
    l.remaining.subtract(n);
    return n;
}

/// Forwards a discard to the underlying reader, capped by `remaining`.
fn passthruDiscard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize {
    const l: *Limited = @alignCast(@ptrCast(context));
    const combined_limit = limit.min(l.remaining);
    const n = try l.unlimited_reader.discard(combined_limit);
    l.remaining.subtract(n);
    return n;
}

/// Forwards a vectored read, truncating `data` so that no more than
/// `remaining` bytes can be read.
fn passthruReadVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize {
    const l: *Limited = @alignCast(@ptrCast(context));
    if (data.len == 0) return 0;
    // Fixed: the original referenced `l.limit`, a field that does not exist;
    // the byte budget lives in `l.remaining`.
    const remaining_bytes = @intFromEnum(l.remaining);
    if (data[0].len >= remaining_bytes) {
        // The first buffer alone reaches the limit: truncate it and read
        // only into that single buffer.
        const n = try l.unlimited_reader.readVec(&.{l.remaining.slice(data[0])});
        l.remaining.subtract(n);
        return n;
    }
    var total: usize = 0;
    for (data, 0..) |buf, i| {
        total += buf.len;
        if (total > remaining_bytes) {
            // Forward only the prefix of whole buffers that fits the limit.
            const n = try l.unlimited_reader.readVec(data[0..i]);
            l.remaining.subtract(n);
            return n;
        }
    }
    // All buffers fit within the remaining limit; forward the full vector.
    // (The original fell through to `return 0` here, which would report no
    // progress even though the limit had not been reached.)
    const n = try l.unlimited_reader.readVec(data);
    l.remaining.subtract(n);
    return n;
}

View File

@ -95,10 +95,12 @@ pub const Offset = enum(u64) {
};
/// Performs a vectored write of the slices in `data`.
/// Implemented in terms of the `writeSplat` vtable entry with a splat
/// count of 1 (i.e. no repetition of the final slice).
/// Asserts that `data` is non-empty.
pub fn writeVec(w: Writer, data: []const []const u8) Error!usize {
    assert(data.len > 0);
    return w.vtable.writeSplat(w.context, data, 1);
}
/// Vectored write with a splat count — presumably the final slice in `data`
/// is repeated `splat` times by the implementation; verify against the
/// concrete vtable implementations. Dispatches directly to the vtable.
/// Asserts that `data` is non-empty.
pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) Error!usize {
    assert(data.len > 0);
    return w.vtable.writeSplat(w.context, data, splat);
}

View File

@ -1,44 +0,0 @@
const std = @import("../std.zig");
const builtin = @import("builtin");
const io = std.io;
const testing = std.testing;
pub const CWriter = io.Writer(*std.c.FILE, std.fs.File.WriteError, cWriterWrite);
pub fn cWriter(c_file: *std.c.FILE) CWriter {
return .{ .context = c_file };
}
fn cWriterWrite(c_file: *std.c.FILE, bytes: []const u8) std.fs.File.WriteError!usize {
const amt_written = std.c.fwrite(bytes.ptr, 1, bytes.len, c_file);
if (amt_written >= 0) return amt_written;
switch (@as(std.c.E, @enumFromInt(std.c._errno().*))) {
.SUCCESS => unreachable,
.INVAL => unreachable,
.FAULT => unreachable,
.AGAIN => unreachable, // this is a blocking API
.BADF => unreachable, // always a race condition
.DESTADDRREQ => unreachable, // connect was never called
.DQUOT => return error.DiskQuota,
.FBIG => return error.FileTooBig,
.IO => return error.InputOutput,
.NOSPC => return error.NoSpaceLeft,
.PERM => return error.PermissionDenied,
.PIPE => return error.BrokenPipe,
else => |err| return std.posix.unexpectedErrno(err),
}
}
test cWriter {
if (!builtin.link_libc or builtin.os.tag == .wasi) return error.SkipZigTest;
const filename = "tmp_io_test_file.txt";
const out_file = std.c.fopen(filename, "w") orelse return error.UnableToOpenTestFile;
defer {
_ = std.c.fclose(out_file);
std.fs.cwd().deleteFileZ(filename) catch {};
}
const writer = cWriter(out_file);
try writer.print("hi: {}\n", .{@as(i32, 123)});
}

View File

@ -1,45 +0,0 @@
const std = @import("../std.zig");
const io = std.io;
const assert = std.debug.assert;
const testing = std.testing;
pub fn LimitedReader(comptime ReaderType: type) type {
return struct {
inner_reader: ReaderType,
bytes_left: u64,
pub const Error = ReaderType.Error;
pub const Reader = io.Reader(*Self, Error, read);
const Self = @This();
pub fn read(self: *Self, dest: []u8) Error!usize {
const max_read = @min(self.bytes_left, dest.len);
const n = try self.inner_reader.read(dest[0..max_read]);
self.bytes_left -= n;
return n;
}
pub fn reader(self: *Self) Reader {
return .{ .context = self };
}
};
}
/// Returns an initialised `LimitedReader`.
/// `bytes_left` is a `u64` to be able to take 64 bit file offsets
pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) {
return .{ .inner_reader = inner_reader, .bytes_left = bytes_left };
}
test "basic usage" {
const data = "hello world";
var fbs = std.io.fixedBufferStream(data);
var early_stream = limitedReader(fbs.reader(), 3);
var buf: [5]u8 = undefined;
try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf));
try testing.expectEqualSlices(u8, data[0..3], buf[0..3]);
try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf));
try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{}));
}

View File

@ -1919,9 +1919,9 @@ pub const Stream = struct {
limit: std.io.Reader.Limit,
) std.io.Reader.Error!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const status = try readVec(context, &.{buf});
bw.advance(status.len);
return status;
const n = try readVec(context, &.{buf});
bw.advance(n);
return n;
}
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {

View File

@ -154,35 +154,30 @@ pub fn findEndRecord(seekable_stream: anytype, stream_len: u64) !EndRecord {
pub fn decompress(
method: CompressionMethod,
uncompressed_size: u64,
reader: anytype,
writer: anytype,
reader: *std.io.BufferedReader,
writer: *std.io.BufferedWriter,
compressed_remaining: *u64,
) !u32 {
var hash = std.hash.Crc32.init();
var total_uncompressed: u64 = 0;
switch (method) {
.store => {
var buf: [4096]u8 = undefined;
while (true) {
const len = try reader.read(&buf);
if (len == 0) break;
try writer.writeAll(buf[0..len]);
hash.update(buf[0..len]);
total_uncompressed += @intCast(len);
}
reader.writeAll(writer, .limited(compressed_remaining.*)) catch |err| switch (err) {
error.EndOfStream => return error.ZipDecompressTruncated,
else => |e| return e,
};
total_uncompressed += compressed_remaining.*;
},
.deflate => {
var br = std.io.bufferedReader(reader);
var decompressor = std.compress.flate.decompressor(br.reader());
var decompressor: std.compress.flate.Decompressor = .init(reader);
while (try decompressor.next()) |chunk| {
try writer.writeAll(chunk);
hash.update(chunk);
total_uncompressed += @intCast(chunk.len);
if (total_uncompressed > uncompressed_size)
return error.ZipUncompressSizeTooSmall;
compressed_remaining.* -= chunk.len;
}
if (br.end != br.start)
return error.ZipDeflateTruncated;
},
_ => return error.UnsupportedCompressionMethod,
}
@ -552,15 +547,15 @@ pub fn Iterator(comptime SeekableStream: type) type {
@as(u64, @sizeOf(LocalFileHeader)) +
local_data_header_offset;
try stream.seekTo(local_data_file_offset);
var limited_reader = std.io.limitedReader(stream.context.reader(), self.compressed_size);
var compressed_remaining: u64 = self.compressed_size;
const crc = try decompress(
self.compression_method,
self.uncompressed_size,
limited_reader.reader(),
stream.context.reader(),
out_file.writer(),
&compressed_remaining,
);
if (limited_reader.bytes_left != 0)
return error.ZipDecompressTruncated;
if (compressed_remaining != 0) return error.ZipDecompressTruncated;
return crc;
}
};

View File

@ -1199,7 +1199,7 @@ fn unpackResource(
return try unpackTarball(f, tmp_directory.handle, dcp.reader());
},
.@"tar.zst" => {
const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len;
const window_size = std.compress.zstd.default_window_len;
const window_buffer = try f.arena.allocator().create([window_size]u8);
const reader = resource.reader();
var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);

View File

@ -1490,7 +1490,7 @@ fn readObjectRaw(allocator: Allocator, reader: anytype, size: u64) ![]u8 {
///
/// The format of the delta data is documented in
/// [pack-format](https://git-scm.com/docs/pack-format).
fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !void {
fn expandDelta(base_object: anytype, delta_reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) !void {
while (true) {
const inst: packed struct { value: u7, copy: bool } = @bitCast(delta_reader.readByte() catch |e| switch (e) {
error.EndOfStream => return,
@ -1521,13 +1521,11 @@ fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !vo
var size: u24 = @bitCast(size_parts);
if (size == 0) size = 0x10000;
try base_object.seekTo(offset);
var copy_reader = std.io.limitedReader(base_object.reader(), size);
var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init();
try fifo.pump(copy_reader.reader(), writer);
var base_object_br = base_object.reader();
try base_object_br.readAll(writer, .limited(size));
} else if (inst.value != 0) {
var data_reader = std.io.limitedReader(delta_reader, inst.value);
var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init();
try fifo.pump(data_reader.reader(), writer);
try delta_reader.readAll(writer, .limited(inst.value));
} else {
return error.InvalidDeltaInstruction;
}