diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 00f475df00..72ed40e03d 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -12,12 +12,11 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool const Self = @This(); allocator: Allocator, - in_reader: ReaderType, - state: enum { NewFrame, InFrame }, + source: std.io.CountingReader(ReaderType), + state: enum { NewFrame, InFrame, LastBlock }, decode_state: decompress.block.DecodeState, frame_context: decompress.FrameContext, buffer: RingBuffer, - last_block: bool, literal_fse_buffer: []types.compressed_block.Table.Fse, match_fse_buffer: []types.compressed_block.Table.Fse, offset_fse_buffer: []types.compressed_block.Table.Fse, @@ -32,12 +31,11 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool pub fn init(allocator: Allocator, source: ReaderType) !Self { return Self{ .allocator = allocator, - .in_reader = source, + .source = std.io.countingReader(source), .state = .NewFrame, .decode_state = undefined, .frame_context = undefined, .buffer = undefined, - .last_block = undefined, .literal_fse_buffer = undefined, .match_fse_buffer = undefined, .offset_fse_buffer = undefined, @@ -48,22 +46,16 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool } fn frameInit(self: *Self) !void { - var bytes: [4]u8 = undefined; - const bytes_read = try self.in_reader.readAll(&bytes); - if (bytes_read == 0) return error.NoBytes; - if (bytes_read < 4) return error.EndOfStream; - const frame_type = try decompress.frameType(std.mem.readIntLittle(u32, &bytes)); - switch (frame_type) { - .skippable => { - const size = try self.in_reader.readIntLittle(u32); - try self.in_reader.skipBytes(size, .{}); + const source_reader = self.source.reader(); + switch (try decompress.decodeFrameHeader(source_reader)) { + .skippable => |header| { + try source_reader.skipBytes(header.frame_size, .{}); self.state = .NewFrame; }, - .zstandard => { + .zstandard => |header| { const frame_context = context: { - const frame_header = try decompress.decodeZstandardHeader(self.in_reader); break :context try decompress.FrameContext.init( - frame_header, + header, window_size_max, verify_checksum, ); @@ -112,7 +104,6 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool self.frame_context = frame_context; self.checksum = if (verify_checksum) null else {}; - self.last_block = false; self.state = .InFrame; }, @@ -134,10 +125,14 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool } pub fn read(self: *Self, buffer: []u8) Error!usize { + const initial_count = self.source.bytes_read; if (buffer.len == 0) return 0; while (self.state == .NewFrame) { self.frameInit() catch |err| switch (err) { - error.NoBytes => return 0, + error.EndOfStream => return if (self.source.bytes_read == initial_count) + 0 + else + error.MalformedFrame, error.OutOfMemory => return error.OutOfMemory, else => return error.MalformedFrame, }; @@ -147,15 +142,16 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool } fn readInner(self: *Self, buffer: []u8) Error!usize { - std.debug.assert(self.state == .InFrame); + std.debug.assert(self.state != .NewFrame); - if (self.buffer.isEmpty() and !self.last_block) { - const header_bytes = self.in_reader.readBytesNoEof(3) catch return error.MalformedFrame; + const source_reader = self.source.reader(); + while (self.buffer.isEmpty() and self.state != .LastBlock) { + const header_bytes = source_reader.readBytesNoEof(3) catch return error.MalformedFrame; const block_header = decompress.block.decodeBlockHeader(&header_bytes); decompress.block.decodeBlockReader( &self.buffer, - self.in_reader, + source_reader, block_header, &self.decode_state, self.frame_context.block_size_max, @@ -164,15 +160,18 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool ) catch return error.MalformedBlock; - self.last_block = block_header.last_block; if (self.frame_context.hasher_opt) |*hasher| { - const written_slice = self.buffer.sliceLast(self.buffer.len()); - hasher.update(written_slice.first); - hasher.update(written_slice.second); + const size = self.buffer.len(); + if (size > 0) { + const written_slice = self.buffer.sliceLast(size); + hasher.update(written_slice.first); + hasher.update(written_slice.second); + } } if (block_header.last_block) { + self.state = .LastBlock; if (self.frame_context.has_checksum) { - const checksum = self.in_reader.readIntLittle(u32) catch return error.MalformedFrame; + const checksum = source_reader.readIntLittle(u32) catch return error.MalformedFrame; if (comptime verify_checksum) { if (self.frame_context.hasher_opt) |*hasher| { if (checksum != decompress.computeChecksum(hasher)) return error.ChecksumFailure; @@ -187,7 +186,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool while (written_count < decoded_data_len and written_count < buffer.len) : (written_count += 1) { buffer[written_count] = self.buffer.read().?; } - if (self.buffer.len() == 0) { + if (self.state == .LastBlock and self.buffer.len() == 0) { self.state = .NewFrame; self.allocator.free(self.literal_fse_buffer); self.allocator.free(self.match_fse_buffer); @@ -219,7 +218,7 @@ fn testReader(data: []const u8, comptime expected: []const u8) !void { try std.testing.expectEqualSlices(u8, expected, buf); } -test "decompression" { +test "zstandard decompression" { const uncompressed = @embedFile("testdata/rfc8478.txt"); const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3"); const compressed19 = @embedFile("testdata/rfc8478.txt.zst.19"); diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index 2503180023..af87ec694f 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -795,6 +795,7 @@ pub fn decodeBlockReader( if (block_size_max < block_size) return error.BlockSizeOverMaximum; switch (block_header.block_type) { .raw => { + if (block_size == 0) return; const slice = dest.sliceAt(dest.write_index, block_size); try source.readNoEof(slice.first); try source.readNoEof(slice.second);