From 25ac70f973c84bb26ebb1b69eda30d2c6207c9b0 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 30 Apr 2025 14:18:45 -0700 Subject: [PATCH] std: WIP update more to new reader/writer delete some bad readers/writers add limited reader update TLS about to do something drastic to compress --- lib/std/compress.zig | 64 +------ lib/std/compress/flate/inflate.zig | 2 +- lib/std/compress/zstandard.zig | 182 ++++++++++-------- lib/std/compress/zstandard/decode/block.zig | 16 +- lib/std/compress/zstandard/decompress.zig | 23 ++- lib/std/crypto/tls.zig | 193 ++++++++++---------- lib/std/crypto/tls/Client.zig | 106 +++++------ lib/std/http/Client.zig | 6 +- lib/std/io.zig | 25 +-- lib/std/io/Reader.zig | 9 + lib/std/io/Reader/Limited.zig | 55 ++++++ lib/std/io/Writer.zig | 2 + lib/std/io/c_writer.zig | 44 ----- lib/std/io/limited_reader.zig | 45 ----- lib/std/net.zig | 6 +- lib/std/zip.zig | 33 ++-- src/Package/Fetch.zig | 2 +- src/Package/Fetch/git.zig | 12 +- 18 files changed, 388 insertions(+), 437 deletions(-) create mode 100644 lib/std/io/Reader/Limited.zig delete mode 100644 lib/std/io/c_writer.zig delete mode 100644 lib/std/io/limited_reader.zig diff --git a/lib/std/compress.zig b/lib/std/compress.zig index 7cc4a80d33..bdc388b84f 100644 --- a/lib/std/compress.zig +++ b/lib/std/compress.zig @@ -1,75 +1,19 @@ //! Compression algorithms. -const std = @import("std.zig"); - pub const flate = @import("compress/flate.zig"); pub const gzip = @import("compress/gzip.zig"); -pub const zlib = @import("compress/zlib.zig"); pub const lzma = @import("compress/lzma.zig"); pub const lzma2 = @import("compress/lzma2.zig"); pub const xz = @import("compress/xz.zig"); +pub const zlib = @import("compress/zlib.zig"); pub const zstd = @import("compress/zstandard.zig"); -pub fn HashedReader(ReaderType: type, HasherType: type) type { - return struct { - child_reader: ReaderType, - hasher: HasherType, - - pub const Error = ReaderType.Error; - pub const Reader = std.io.Reader(*@This(), Error, read); - - pub fn read(self: *@This(), buf: []u8) Error!usize { - const amt = try self.child_reader.read(buf); - self.hasher.update(buf[0..amt]); - return amt; - } - - pub fn reader(self: *@This()) Reader { - return .{ .context = self }; - } - }; -} - -pub fn hashedReader( - reader: anytype, - hasher: anytype, -) HashedReader(@TypeOf(reader), @TypeOf(hasher)) { - return .{ .child_reader = reader, .hasher = hasher }; -} - -pub fn HashedWriter(WriterType: type, HasherType: type) type { - return struct { - child_writer: WriterType, - hasher: HasherType, - - pub const Error = WriterType.Error; - pub const Writer = std.io.Writer(*@This(), Error, write); - - pub fn write(self: *@This(), buf: []const u8) Error!usize { - const amt = try self.child_writer.write(buf); - self.hasher.update(buf[0..amt]); - return amt; - } - - pub fn writer(self: *@This()) Writer { - return .{ .context = self }; - } - }; -} - -pub fn hashedWriter( - writer: anytype, - hasher: anytype, -) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) { - return .{ .child_writer = writer, .hasher = hasher }; -} - test { + _ = flate; + _ = gzip; _ = lzma; _ = lzma2; _ = xz; - _ = zstd; - _ = flate; - _ = gzip; _ = zlib; + _ = zstd; } diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index 951b13b0c7..1d89e7f7c8 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -821,7 +821,7 @@ pub fn BitReader(comptime T: type) type { /// Skip zero terminated string. pub fn skipStringZ(self: *Self) !void { while (true) { - if (try self.readF(u8, 0) == 0) break; + if (try self.readF(u8, .{}) == 0) break; } } diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index cf1775ed11..4ed14276ea 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -1,7 +1,12 @@ -const std = @import("std"); +const std = @import("../std.zig"); const RingBuffer = std.RingBuffer; const types = @import("zstandard/types.zig"); + +/// Recommended amount by the standard. Lower than this may result in inability +/// to decompress common streams. +pub const default_window_len = 8 * 1024 * 1024; + pub const frame = types.frame; pub const compressed_block = types.compressed_block; @@ -10,7 +15,8 @@ pub const decompress = @import("zstandard/decompress.zig"); pub const Decompressor = struct { const table_size_max = types.compressed_block.table_size_max; - source: *std.io.BufferedReader, + input: *std.io.BufferedReader, + bytes_read: usize, state: enum { NewFrame, InFrame, LastBlock }, decode_state: decompress.block.DecodeState, frame_context: decompress.FrameContext, @@ -23,14 +29,12 @@ pub const Decompressor = struct { verify_checksum: bool, checksum: ?u32, current_frame_decompressed_size: usize, + err: ?Error = null, pub const Options = struct { verify_checksum: bool = true, + /// See `default_window_len`. window_buffer: []u8, - - /// Recommended amount by the standard. Lower than this may result - /// in inability to decompress common streams. - pub const default_window_buffer_len = 8 * 1024 * 1024; }; const WindowBuffer = struct { @@ -45,11 +49,13 @@ pub const Decompressor = struct { MalformedBlock, MalformedFrame, OutOfMemory, + EndOfStream, }; - pub fn init(source: *std.io.BufferedReader, options: Options) Decompressor { + pub fn init(input: *std.io.BufferedReader, options: Options) Decompressor { return .{ - .source = source, + .input = input, + .bytes_read = 0, .state = .NewFrame, .decode_state = undefined, .frame_context = undefined, @@ -65,100 +71,128 @@ pub const Decompressor = struct { }; } - fn frameInit(self: *Decompressor) !void { - const source_reader = self.source; - switch (try decompress.decodeFrameHeader(source_reader)) { + fn frameInit(d: *Decompressor) !void { + const in = d.input; + switch (try decompress.decodeFrameHeader(in, &d.bytes_read)) { .skippable => |header| { - try source_reader.skipBytes(header.frame_size, .{}); - self.state = .NewFrame; + try in.discardAll(header.frame_size); + d.bytes_read += header.frame_size; + d.state = .NewFrame; }, .zstandard => |header| { const frame_context = try decompress.FrameContext.init( header, - self.buffer.data.len, - self.verify_checksum, + d.buffer.data.len, + d.verify_checksum, ); const decode_state = decompress.block.DecodeState.init( - &self.literal_fse_buffer, - &self.match_fse_buffer, - &self.offset_fse_buffer, + &d.literal_fse_buffer, + &d.match_fse_buffer, + &d.offset_fse_buffer, ); - self.decode_state = decode_state; - self.frame_context = frame_context; + d.decode_state = decode_state; + d.frame_context = frame_context; - self.checksum = null; - self.current_frame_decompressed_size = 0; + d.checksum = null; + d.current_frame_decompressed_size = 0; - self.state = .InFrame; + d.state = .InFrame; }, } } pub fn reader(self: *Decompressor) std.io.Reader { - return .{ .context = self }; + return .{ + .context = self, + .vtable = &.{ + .read = read, + .readVec = readVec, + .discard = discard, + }, + }; } - pub fn read(self: *Decompressor, buffer: []u8) Error!usize { - if (buffer.len == 0) return 0; + fn read(context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit) std.io.Reader.RwError!usize { + const buf = limit.slice(try bw.writableSliceGreedy(1)); + const n = try readVec(context, &.{buf}); + bw.advance(n); + return n; + } - var size: usize = 0; - while (size == 0) { - while (self.state == .NewFrame) { - const initial_count = self.source.bytes_read; - self.frameInit() catch |err| switch (err) { - error.DictionaryIdFlagUnsupported => return error.DictionaryIdFlagUnsupported, - error.EndOfStream => return if (self.source.bytes_read == initial_count) - 0 - else - error.MalformedFrame, - else => return error.MalformedFrame, - }; - } - size = try self.readInner(buffer); + fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize { + var trash: [128]u8 = undefined; + const buf = limit.slice(&trash); + return readVec(context, &.{buf}); + } + + fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { + const d: *Decompressor = @ptrCast(@alignCast(context)); + if (data.len == 0) return 0; + const buffer = data[0]; + while (d.state == .NewFrame) { + const initial_count = d.bytes_read; + d.frameInit() catch |err| switch (err) { + error.DictionaryIdFlagUnsupported => { + d.err = error.DictionaryIdFlagUnsupported; + return error.ReadFailed; + }, + error.EndOfStream => { + if (d.bytes_read == initial_count) return error.EndOfStream; + d.err = error.MalformedFrame; + return error.ReadFailed; + }, + else => { + d.err = error.MalformedFrame; + return error.ReadFailed; + }, + }; } - return size; + return d.readInner(buffer) catch |err| { + d.err = err; + return error.ReadFailed; + }; } - fn readInner(self: *Decompressor, buffer: []u8) Error!usize { - std.debug.assert(self.state != .NewFrame); + fn readInner(d: *Decompressor, buffer: []u8) Error!usize { + std.debug.assert(d.state != .NewFrame); var ring_buffer = RingBuffer{ - .data = self.buffer.data, - .read_index = self.buffer.read_index, - .write_index = self.buffer.write_index, + .data = d.buffer.data, + .read_index = d.buffer.read_index, + .write_index = d.buffer.write_index, }; defer { - self.buffer.read_index = ring_buffer.read_index; - self.buffer.write_index = ring_buffer.write_index; + d.buffer.read_index = ring_buffer.read_index; + d.buffer.write_index = ring_buffer.write_index; } - const source_reader = self.source; - while (ring_buffer.isEmpty() and self.state != .LastBlock) { - const header_bytes = source_reader.readBytesNoEof(3) catch - return error.MalformedFrame; - const block_header = decompress.block.decodeBlockHeader(&header_bytes); + const in = d.input; + while (ring_buffer.isEmpty() and d.state != .LastBlock) { + const header_bytes = try in.takeArray(3); + d.bytes_read += header_bytes.len; + const block_header = decompress.block.decodeBlockHeader(header_bytes); decompress.block.decodeBlockReader( &ring_buffer, - source_reader, + in, + &d.bytes_read, block_header, - &self.decode_state, - self.frame_context.block_size_max, - &self.literals_buffer, - &self.sequence_buffer, - ) catch - return error.MalformedBlock; + &d.decode_state, + d.frame_context.block_size_max, + &d.literals_buffer, + &d.sequence_buffer, + ) catch return error.MalformedBlock; - if (self.frame_context.content_size) |size| { - if (self.current_frame_decompressed_size > size) return error.MalformedFrame; + if (d.frame_context.content_size) |size| { + if (d.current_frame_decompressed_size > size) return error.MalformedFrame; } const size = ring_buffer.len(); - self.current_frame_decompressed_size += size; + d.current_frame_decompressed_size += size; - if (self.frame_context.hasher_opt) |*hasher| { + if (d.frame_context.hasher_opt) |*hasher| { if (size > 0) { const written_slice = ring_buffer.sliceLast(size); hasher.update(written_slice.first); @@ -166,19 +200,19 @@ pub const Decompressor = struct { } } if (block_header.last_block) { - self.state = .LastBlock; - if (self.frame_context.has_checksum) { - const checksum = source_reader.readInt(u32, .little) catch - return error.MalformedFrame; - if (self.verify_checksum) { - if (self.frame_context.hasher_opt) |*hasher| { + d.state = .LastBlock; + if (d.frame_context.has_checksum) { + const checksum = in.readInt(u32, .little) catch return error.MalformedFrame; + d.bytes_read += 4; + if (d.verify_checksum) { + if (d.frame_context.hasher_opt) |*hasher| { if (checksum != decompress.computeChecksum(hasher)) return error.ChecksumFailure; } } } - if (self.frame_context.content_size) |content_size| { - if (content_size != self.current_frame_decompressed_size) { + if (d.frame_context.content_size) |content_size| { + if (content_size != d.current_frame_decompressed_size) { return error.MalformedFrame; } } @@ -189,8 +223,8 @@ pub const Decompressor = struct { if (size > 0) { ring_buffer.readFirstAssumeLength(buffer, size); } - if (self.state == .LastBlock and ring_buffer.len() == 0) { - self.state = .NewFrame; + if (d.state == .LastBlock and ring_buffer.len() == 0) { + d.state = .NewFrame; } return size; } diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index e71a20e56b..881915c613 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -807,7 +807,8 @@ pub fn decodeBlockRingBuffer( /// contain enough bytes. pub fn decodeBlockReader( dest: *RingBuffer, - source: anytype, + in: *std.io.BufferedReader, + bytes_read: *usize, block_header: frame.Zstandard.Block.Header, decode_state: *DecodeState, block_size_max: usize, @@ -815,26 +816,29 @@ pub fn decodeBlockReader( sequence_buffer: []u8, ) !void { const block_size = block_header.block_size; - var block_reader_limited = std.io.limitedReader(source, block_size); - const block_reader = block_reader_limited.reader(); if (block_size_max < block_size) return error.BlockSizeOverMaximum; switch (block_header.block_type) { .raw => { if (block_size == 0) return; const slice = dest.sliceAt(dest.write_index, block_size); - try source.readNoEof(slice.first); - try source.readNoEof(slice.second); + var vecs: [2][]u8 = &.{slice.first, slice.second }; + try in.readVecAll(&vecs); + assert(slice.first.len + slice.second.len == block_size); + bytes_read.* += block_size; dest.write_index = dest.mask2(dest.write_index + block_size); decode_state.written_count += block_size; }, .rle => { - const byte = try source.readByte(); + const byte = try in.takeByte(); + bytes_read.* += 1; for (0..block_size) |_| { dest.writeAssumeCapacity(byte); } decode_state.written_count += block_size; }, .compressed => { + var block_reader_limited = std.io.limitedReader(source, block_size); + const block_reader = block_reader_limited.reader(); const literals = try decodeLiteralsSection(block_reader, literals_buffer); const sequences_header = try decodeSequencesHeader(block_reader); diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 90fa12e3cb..6acb21a52d 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -50,7 +50,7 @@ pub const FrameHeader = union(enum) { skippable: SkippableHeader, }; -pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet }; +pub const HeaderError = error{ ReadFailed, BadMagic, EndOfStream, ReservedBitSet }; /// Returns the header of the frame at the beginning of `source`. /// @@ -61,16 +61,21 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet }; /// - `error.EndOfStream` if `source` contains fewer than 4 bytes /// - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the /// reserved bits are set -pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader { - const magic = try source.readInt(u32, .little); +pub fn decodeFrameHeader(br: *std.io.BufferedReader, bytes_read: *usize) HeaderError!FrameHeader { + const magic = try br.readInt(u32, .little); + bytes_read.* += 4; const frame_type = try frameType(magic); switch (frame_type) { - .zstandard => return FrameHeader{ .zstandard = try decodeZstandardHeader(source) }, - .skippable => return FrameHeader{ - .skippable = .{ - .magic_number = magic, - .frame_size = try source.readInt(u32, .little), - }, + .zstandard => return .{ .zstandard = try decodeZstandardHeader(br, bytes_read) }, + .skippable => { + const result: FrameHeader = .{ + .skippable = .{ + .magic_number = magic, + .frame_size = try br.readInt(u32, .little), + }, + }; + bytes_read.* += 4; + return result; }, } } diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 74113225cb..c2071bcbc5 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{ }; pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.Description.close_notify), }; pub const ProtocolVersion = enum(u16) { @@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) { _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; +pub const Alert = struct { + level: Level, + description: Description, -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + pub const Description = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; - pub fn toError(alert: AlertDescription) Error!void { - switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => return error.TlsAlertUnexpectedMessage, - .bad_record_mac => return error.TlsAlertBadRecordMac, - .record_overflow => return error.TlsAlertRecordOverflow, - .handshake_failure => return error.TlsAlertHandshakeFailure, - .bad_certificate => return error.TlsAlertBadCertificate, - .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, - .certificate_revoked => return error.TlsAlertCertificateRevoked, - .certificate_expired => return error.TlsAlertCertificateExpired, - .certificate_unknown => return error.TlsAlertCertificateUnknown, - .illegal_parameter => return error.TlsAlertIllegalParameter, - .unknown_ca => return error.TlsAlertUnknownCa, - .access_denied => return error.TlsAlertAccessDenied, - .decode_error => return error.TlsAlertDecodeError, - .decrypt_error => return error.TlsAlertDecryptError, - .protocol_version => return error.TlsAlertProtocolVersion, - .insufficient_security => return error.TlsAlertInsufficientSecurity, - .internal_error => return error.TlsAlertInternalError, - .inappropriate_fallback => return error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => return error.TlsAlertMissingExtension, - .unsupported_extension => return error.TlsAlertUnsupportedExtension, - .unrecognized_name => return error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, - .certificate_required => return error.TlsAlertCertificateRequired, - .no_application_protocol => return error.TlsAlertNoApplicationProtocol, - _ => return error.TlsAlertUnknown, + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(description: Description) Error!void { + switch (description) { + .close_notify => {}, // not an error + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, + .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } - } + }; }; pub const SignatureScheme = enum(u16) { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 00e5a03bef..a393691828 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -39,8 +39,9 @@ output: *std.io.BufferedWriter, /// /// Its buffer aliases the buffer of `input`. reader: std.io.BufferedReader, -/// Populated under various error conditions. -diagnostics: Diagnostics, +/// Populated when `error.TlsAlert` is returned. +alert: ?tls.Alert, +read_err: ?ReadError, tls_version: tls.ProtocolVersion, read_seq: u64, @@ -69,15 +70,16 @@ application_cipher: tls.ApplicationCipher, /// this connection. ssl_key_log: ?*SslKeyLog, -pub const Diagnostics = union(enum) { - /// Any `ReadFailure` and `WriteFailure` was due to `input` or `output` - /// returning the error, respectively. - transitive, - /// Populated on `error.TlsAlert`. - /// - /// If this isn't a error alert, then it's a closure alert, which makes - /// no sense in a handshake. - alert: tls.AlertDescription, +pub const ReadError = error{ + /// The alert description will be stored in `alert`. + TlsAlert, + TlsBadLength, + TlsBadRecordMac, + TlsConnectionTruncated, + TlsDecodeError, + TlsRecordOverflow, + TlsUnexpectedMessage, + TlsIllegalParameter, }; pub const SslKeyLog = struct { @@ -128,14 +130,13 @@ pub const Options = struct { }; const InitError = error{ - //OutOfMemory, - WriteFailure, - ReadFailure, + WriteFailed, + ReadFailed, InsufficientEntropy, DiskQuota, LockViolation, NotOpenForWriting, - /// The alert description will be stored in `Options.Diagnostics.alert`. + /// The alert description will be stored in `alert`. TlsAlert, TlsUnexpectedMessage, TlsIllegalParameter, @@ -192,7 +193,7 @@ pub fn init( ) InitError!void { assert(input.storage.buffer.len >= min_buffer_len); assert(output.buffer.len >= min_buffer_len); - client.diagnostics = .transient; + client.alert = null; const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -417,10 +418,10 @@ pub fn init( switch (ct) { .alert => { ctd.ensure(2) catch continue :fragment; - const level = ctd.decode(tls.AlertLevel); - const desc = ctd.decode(tls.AlertDescription); - _ = level; - client.diagnostics = .{ .alert = desc }; + client.alert = .{ + .level = ctd.decode(tls.Alert.Level), + .description = ctd.decode(tls.Alert.Description), + }; return error.TlsAlert; }, .change_cipher_spec => { @@ -924,7 +925,7 @@ pub fn writer(c: *Client) std.io.Writer { fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize { const c: *Client = @alignCast(@ptrCast(context)); const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; - const output = &c.output; + const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); var total_clear: usize = 0; var ciphertext_end: usize = 0; @@ -942,7 +943,7 @@ fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.i /// distinguish between a properly finished TLS session, or a truncation /// attack. pub fn end(c: *Client) std.io.Writer.Error!void { - const output = &c.output; + const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); output.advance(prepared.cleartext_len); @@ -1062,16 +1063,16 @@ fn read( context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit, -) std.io.Reader.RwError!std.io.Reader.Status { +) std.io.Reader.RwError!usize { const buf = limit.slice(try bw.writableSliceGreedy(1)); - const status = try readVec(context, &.{buf}); - bw.advance(status.len); - return status; + const n = try readVec(context, &.{buf}); + bw.advance(n); + return n; } fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { const c: *Client = @ptrCast(@alignCast(context)); - if (c.eof()) return .{ .end = true }; + if (c.eof()) return error.EndOfStream; var vp: VecPut = .{ .iovecs = data }; @@ -1093,11 +1094,11 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { if (c.received_close_notify) { c.partial_ciphertext_end = 0; assert(vp.total == amt); - return .{ .len = amt, .end = c.eof() }; + return amt; } else if (amt > 0) { // We don't need more data, so don't call read. assert(vp.total == amt); - return .{ .len = amt, .end = c.eof() }; + return amt; } } @@ -1149,7 +1150,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { if (c.allow_truncation_attacks) { c.received_close_notify = true; } else { - return error.TlsConnectionTruncated; + return failRead(c, error.TlsConnectionTruncated); } } @@ -1168,7 +1169,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { // Perfect split. if (frag.ptr == frag1.ptr) { c.partial_ciphertext_end = c.partial_ciphertext_idx; - return .{ .len = vp.total, .end = c.eof() }; + return vp.total; } frag = frag1; in = 0; @@ -1188,7 +1189,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow); const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; @@ -1208,7 +1209,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { in += 2; _ = legacy_version; const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow); in += 2; const the_end = in + record_len; if (the_end > frag.len) { @@ -1255,7 +1256,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { &cleartext_stack_buffer; const cleartext = cleartext_buf[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch - return error.TlsBadRecordMac; + return failRead(c, error.TlsBadRecordMac); const msg = mem.trimEnd(u8, cleartext, "\x00"); break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; }, @@ -1287,7 +1288,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { &cleartext_stack_buffer; const cleartext = cleartext_buf[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch - return error.TlsBadRecordMac; + return failRead(c, error.TlsBadRecordMac); break :cleartext .{ cleartext, ct }; }, else => unreachable, @@ -1296,23 +1297,24 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { c.read_seq = try std.math.add(u64, c.read_seq, 1); switch (inner_ct) { .alert => { - if (cleartext.len != 2) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - _ = level; - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - switch (desc) { + if (cleartext.len != 2) return failRead(c, error.TlsDecodeError); + const alert: tls.Alert = .{ + .level = @enumFromInt(cleartext[0]), + .description = @enumFromInt(cleartext[1]), + }; + switch (alert.description) { .close_notify => { c.received_close_notify = true; c.partial_ciphertext_end = c.partial_ciphertext_idx; - return .{ .len = vp.total, .end = c.eof() }; + return vp.total; }, .user_canceled => { // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + return failRead(c, error.TlsUnexpectedMessage); }, else => { - c.diagnostics = .{ .alert = desc }; - return error.TlsAlert; + c.alert = alert; + return failRead(c, error.TlsAlert); }, } }, @@ -1324,8 +1326,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); ct_i += 3; const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len) - return error.TlsBadLength; + if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength); const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { @@ -1371,12 +1372,10 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { c.write_seq = 0; }, .update_not_requested => {}, - _ => return error.TlsIllegalParameter, + _ => return failRead(c, error.TlsIllegalParameter), } }, - else => { - return error.TlsUnexpectedMessage; - }, + else => return failRead(c, error.TlsUnexpectedMessage), } ct_i = next_handshake_i; if (ct_i >= cleartext.len) break; @@ -1411,7 +1410,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { vp.next(cleartext.len); } }, - else => return error.TlsUnexpectedMessage, + else => return failRead(c, error.TlsUnexpectedMessage), } in = end; } @@ -1423,6 +1422,11 @@ fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error @panic("TODO"); } +fn failRead(c: *Client, err: ReadError) error{ReadFailed} { + c.read_err = err; + return error.ReadFailed; +} + fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; defer if (locked) key_log_file.unlock(); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index ce82c19f1d..4c014908e0 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -28,7 +28,7 @@ tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.cr /// If non-null, ssl secrets are logged to a stream. Creating such a stream /// allows other processes with access to that stream to decrypt all /// traffic over connections created with this `Client`. -ssl_key_logger: ?*std.io.BufferedWriter = null, +ssl_key_log: ?*std.io.BufferedWriter = null, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -342,7 +342,7 @@ pub const Connection = struct { tls.client.init(&tls.reader, &tls.writer, .{ .host = .{ .explicit = remote_host }, .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_logger = client.ssl_key_logger, + .ssl_key_log = client.ssl_key_log, }) catch return error.TlsInitializationFailed; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. @@ -1671,7 +1671,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult { const decompress_buffer: []u8 = switch (response.head.content_encoding) { .identity => &.{}, .zstd => options.decompress_buffer orelse - try client.allocator.alloc(u8, std.compress.zstd.Decompressor.Options.default_window_buffer_len * 2), + try client.allocator.alloc(u8, std.compress.zstd.default_window_len * 2), else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024), }; defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer); diff --git a/lib/std/io.zig b/lib/std/io.zig index 5b316c65d1..292812752d 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -1,16 +1,11 @@ -const std = @import("std.zig"); const builtin = @import("builtin"); -const root = @import("root"); -const c = std.c; const is_windows = builtin.os.tag == .windows; + +const std = @import("std.zig"); const windows = std.os.windows; const posix = std.posix; const math = std.math; const assert = std.debug.assert; -const fs = std.fs; -const mem = std.mem; -const meta = std.meta; -const File = std.fs.File; const Allocator = std.mem.Allocator; const Alignment = std.mem.Alignment; @@ -21,12 +16,6 @@ pub const BufferedReader = @import("io/BufferedReader.zig"); pub const BufferedWriter = @import("io/BufferedWriter.zig"); pub const AllocatingWriter = @import("io/AllocatingWriter.zig"); -pub const CWriter = @import("io/c_writer.zig").CWriter; -pub const cWriter = @import("io/c_writer.zig").cWriter; - -pub const LimitedReader = @import("io/limited_reader.zig").LimitedReader; -pub const limitedReader = @import("io/limited_reader.zig").limitedReader; - pub const MultiWriter = @import("io/multi_writer.zig").MultiWriter; pub const multiWriter = @import("io/multi_writer.zig").multiWriter; @@ -38,9 +27,6 @@ pub const bitWriter = @import("io/bit_writer.zig").bitWriter; pub const ChangeDetectionStream = @import("io/change_detection_stream.zig").ChangeDetectionStream; pub const changeDetectionStream = @import("io/change_detection_stream.zig").changeDetectionStream; -pub const FindByteWriter = @import("io/find_byte_writer.zig").FindByteWriter; -pub const findByteWriter = @import("io/find_byte_writer.zig").findByteWriter; - pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAtomicFile; pub const tty = @import("io/tty.zig"); @@ -63,7 +49,7 @@ pub fn poll( .windows = if (is_windows) .{ .first_read_done = false, .overlapped = [1]windows.OVERLAPPED{ - mem.zeroes(windows.OVERLAPPED), + std.mem.zeroes(windows.OVERLAPPED), } ** enum_fields.len, .small_bufs = undefined, .active = .{ @@ -436,10 +422,10 @@ pub fn PollFiles(comptime StreamEnum: type) type { for (&struct_fields, enum_fields) |*struct_field, enum_field| { struct_field.* = .{ .name = enum_field.name, - .type = fs.File, + .type = std.fs.File, .default_value_ptr = null, .is_comptime = false, - .alignment = @alignOf(fs.File), + .alignment = @alignOf(std.fs.File), }; } return @Type(.{ .@"struct" = .{ @@ -459,6 +445,5 @@ test { _ = @import("io/bit_reader.zig"); _ = @import("io/bit_writer.zig"); _ = @import("io/buffered_atomic_file.zig"); - _ = @import("io/c_writer.zig"); _ = @import("io/test.zig"); } diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index 98b8992df4..2416aa0c54 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -6,6 +6,8 @@ const BufferedReader = std.io.BufferedReader; const Allocator = std.mem.Allocator; const ArrayList = std.ArrayListUnmanaged; +pub const Limited = @import("Reader/Limited.zig"); + context: ?*anyopaque, vtable: *const VTable, @@ -252,6 +254,13 @@ pub fn buffered(r: Reader, buffer: []u8) BufferedReader { }; } +pub fn limited(r: Reader, limit: Limit) Limited { + return .{ + .unlimited_reader = r, + .remaining = limit, + }; +} + fn endingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize { _ = context; _ = bw; diff --git a/lib/std/io/Reader/Limited.zig b/lib/std/io/Reader/Limited.zig new file mode 100644 index 0000000000..fd7197ae47 --- /dev/null +++ b/lib/std/io/Reader/Limited.zig @@ -0,0 +1,55 @@ +const Limited = @This(); + +const std = @import("../../std.zig"); +const Reader = std.io.Reader; +const BufferedWriter = std.io.BufferedWriter; + +unlimited_reader: Reader, +remaining: Reader.Limit, + +pub fn reader(l: *Limited) Reader { + return .{ + .context = l, + .vtable = &.{ + .read = passthruRead, + .readVec = passthruReadVec, + .discard = passthruDiscard, + }, + }; +} + +fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize { + const l: *Limited = @alignCast(@ptrCast(context)); + const combined_limit = limit.min(l.remaining); + const n = try l.unlimited_reader.read(bw, combined_limit); + l.remaining.subtract(n); + return n; +} + +fn passthruDiscard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize { + const l: *Limited = @alignCast(@ptrCast(context)); + const combined_limit = limit.min(l.remaining); + const n = try l.unlimited_reader.discard(combined_limit); + l.remaining.subtract(n); + return n; +} + +fn passthruReadVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize { + const l: *Limited = @alignCast(@ptrCast(context)); + if (data.len == 0) return 0; + if (data[0].len >= @intFromEnum(l.limit)) { + const n = try l.unlimited_reader.readVec(&.{l.limit.slice(data[0])}); + l.remaining.subtract(n); + return n; + } + var total: usize = 0; + for (data, 0..) |buf, i| { + total += buf.len; + if (total > @intFromEnum(l.limit)) { + const n = try l.unlimited_reader.readVec(data[0..i]); + l.remaining.subtract(n); + return n; + } + } + return 0; +} diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index 97c0d97552..10128e14c5 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -95,10 +95,12 @@ pub const Offset = enum(u64) { }; pub fn writeVec(w: Writer, data: []const []const u8) Error!usize { + assert(data.len > 0); return w.vtable.writeSplat(w.context, data, 1); } pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) Error!usize { + assert(data.len > 0); return w.vtable.writeSplat(w.context, data, splat); } diff --git a/lib/std/io/c_writer.zig b/lib/std/io/c_writer.zig deleted file mode 100644 index 8c25e51029..0000000000 --- a/lib/std/io/c_writer.zig +++ /dev/null @@ -1,44 +0,0 @@ -const std = @import("../std.zig"); -const builtin = @import("builtin"); -const io = std.io; -const testing = std.testing; - -pub const CWriter = io.Writer(*std.c.FILE, std.fs.File.WriteError, cWriterWrite); - -pub fn cWriter(c_file: *std.c.FILE) CWriter { - return .{ .context = c_file }; -} - -fn cWriterWrite(c_file: *std.c.FILE, bytes: []const u8) std.fs.File.WriteError!usize { - const amt_written = std.c.fwrite(bytes.ptr, 1, bytes.len, c_file); - if (amt_written >= 0) return amt_written; - switch (@as(std.c.E, @enumFromInt(std.c._errno().*))) { - .SUCCESS => unreachable, - .INVAL => unreachable, - .FAULT => unreachable, - .AGAIN => unreachable, // this is a blocking API - .BADF => unreachable, // always a race condition - .DESTADDRREQ => unreachable, // connect was never called - .DQUOT => return error.DiskQuota, - .FBIG => return error.FileTooBig, - .IO => return error.InputOutput, - .NOSPC => return error.NoSpaceLeft, - .PERM => return error.PermissionDenied, - .PIPE => return error.BrokenPipe, - else => |err| return std.posix.unexpectedErrno(err), - } -} - -test cWriter { - if (!builtin.link_libc or builtin.os.tag == .wasi) return error.SkipZigTest; - - const filename = "tmp_io_test_file.txt"; - const out_file = std.c.fopen(filename, "w") orelse return error.UnableToOpenTestFile; - defer { - _ = std.c.fclose(out_file); - std.fs.cwd().deleteFileZ(filename) catch {}; - } - - const writer = cWriter(out_file); - try writer.print("hi: {}\n", .{@as(i32, 123)}); -} diff --git a/lib/std/io/limited_reader.zig b/lib/std/io/limited_reader.zig deleted file mode 100644 index d7e2503881..0000000000 --- a/lib/std/io/limited_reader.zig +++ /dev/null @@ -1,45 +0,0 @@ -const std = @import("../std.zig"); -const io = std.io; -const assert = std.debug.assert; -const testing = std.testing; - -pub fn LimitedReader(comptime ReaderType: type) type { - return struct { - inner_reader: ReaderType, - bytes_left: u64, - - pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); - - const Self = @This(); - - pub fn read(self: *Self, dest: []u8) Error!usize { - const max_read = @min(self.bytes_left, dest.len); - const n = try self.inner_reader.read(dest[0..max_read]); - self.bytes_left -= n; - return n; - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - }; -} - -/// Returns an initialised `LimitedReader`. -/// `bytes_left` is a `u64` to be able to take 64 bit file offsets -pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) { - return .{ .inner_reader = inner_reader, .bytes_left = bytes_left }; -} - -test "basic usage" { - const data = "hello world"; - var fbs = std.io.fixedBufferStream(data); - var early_stream = limitedReader(fbs.reader(), 3); - - var buf: [5]u8 = undefined; - try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf)); - try testing.expectEqualSlices(u8, data[0..3], buf[0..3]); - try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf)); - try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{})); -} diff --git a/lib/std/net.zig b/lib/std/net.zig index b362728857..b191383500 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1919,9 +1919,9 @@ pub const Stream = struct { limit: std.io.Reader.Limit, ) std.io.Reader.Error!usize { const buf = limit.slice(try bw.writableSliceGreedy(1)); - const status = try readVec(context, &.{buf}); - bw.advance(status.len); - return status; + const n = try readVec(context, &.{buf}); + bw.advance(n); + return n; } fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { diff --git a/lib/std/zip.zig b/lib/std/zip.zig index c149584fd5..66a7da1021 100644 --- a/lib/std/zip.zig +++ b/lib/std/zip.zig @@ -154,35 +154,30 @@ pub fn findEndRecord(seekable_stream: anytype, stream_len: u64) !EndRecord { pub fn decompress( method: CompressionMethod, uncompressed_size: u64, - reader: anytype, - writer: anytype, + reader: *std.io.BufferedReader, + writer: *std.io.BufferedWriter, + compressed_remaining: *u64, ) !u32 { var hash = std.hash.Crc32.init(); - var total_uncompressed: u64 = 0; switch (method) { .store => { - var buf: [4096]u8 = undefined; - while (true) { - const len = try reader.read(&buf); - if (len == 0) break; - try writer.writeAll(buf[0..len]); - hash.update(buf[0..len]); - total_uncompressed += @intCast(len); - } + reader.writeAll(writer, .limited(compressed_remaining.*)) catch |err| switch (err) { + error.EndOfStream => return error.ZipDecompressTruncated, + else => |e| return e, + }; + total_uncompressed += compressed_remaining.*; }, .deflate => { - var br = std.io.bufferedReader(reader); - var decompressor = std.compress.flate.decompressor(br.reader()); + var decompressor: std.compress.flate.Decompressor = .init(reader); while (try decompressor.next()) |chunk| { try writer.writeAll(chunk); hash.update(chunk); total_uncompressed += @intCast(chunk.len); if (total_uncompressed > uncompressed_size) return error.ZipUncompressSizeTooSmall; + compressed_remaining.* -= chunk.len; } - if (br.end != br.start) - return error.ZipDeflateTruncated; }, _ => return error.UnsupportedCompressionMethod, } @@ -552,15 +547,15 @@ pub fn Iterator(comptime SeekableStream: type) type { @as(u64, @sizeOf(LocalFileHeader)) + local_data_header_offset; try stream.seekTo(local_data_file_offset); - var limited_reader = std.io.limitedReader(stream.context.reader(), self.compressed_size); + var compressed_remaining: u64 = self.compressed_size; const crc = try decompress( self.compression_method, self.uncompressed_size, - limited_reader.reader(), + stream.context.reader(), out_file.writer(), + &compressed_remaining, ); - if (limited_reader.bytes_left != 0) - return error.ZipDecompressTruncated; + if (compressed_remaining != 0) return error.ZipDecompressTruncated; return crc; } }; diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 950a44e77e..6f49cbdb3f 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -1199,7 +1199,7 @@ fn unpackResource( return try unpackTarball(f, tmp_directory.handle, dcp.reader()); }, .@"tar.zst" => { - const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len; + const window_size = std.compress.zstd.default_window_len; const window_buffer = try f.arena.allocator().create([window_size]u8); const reader = resource.reader(); var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index a54cb476e7..fe8ee5f667 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -1490,7 +1490,7 @@ fn readObjectRaw(allocator: Allocator, reader: anytype, size: u64) ![]u8 { /// /// The format of the delta data is documented in /// [pack-format](https://git-scm.com/docs/pack-format). -fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !void { +fn expandDelta(base_object: anytype, delta_reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) !void { while (true) { const inst: packed struct { value: u7, copy: bool } = @bitCast(delta_reader.readByte() catch |e| switch (e) { error.EndOfStream => return, @@ -1521,13 +1521,11 @@ fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !vo var size: u24 = @bitCast(size_parts); if (size == 0) size = 0x10000; try base_object.seekTo(offset); - var copy_reader = std.io.limitedReader(base_object.reader(), size); - var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init(); - try fifo.pump(copy_reader.reader(), writer); + + var base_object_br = base_object.reader(); + try base_object_br.readAll(writer, .limited(size)); } else if (inst.value != 0) { - var data_reader = std.io.limitedReader(delta_reader, inst.value); - var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init(); - try fifo.pump(data_reader.reader(), writer); + try delta_reader.readAll(writer, .limited(inst.value)); } else { return error.InvalidDeltaInstruction; }