diff --git a/lib/std/compress/zstd.zig b/lib/std/compress/zstd.zig index 2209d2d936..2ed25e0931 100644 --- a/lib/std/compress/zstd.zig +++ b/lib/std/compress/zstd.zig @@ -81,12 +81,11 @@ pub const table_size_max = struct { fn testDecompress(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 { var out: std.ArrayListUnmanaged(u8) = .empty; defer out.deinit(gpa); - try out.ensureUnusedCapacity(gpa, default_window_len); var in: std.io.BufferedReader = undefined; in.initFixed(@constCast(compressed)); var zstd_stream: Decompress = .init(&in, .{}); - try zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited); + try zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited, default_window_len); return out.toOwnedSlice(gpa); } @@ -103,16 +102,18 @@ fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void { var out: std.ArrayListUnmanaged(u8) = .empty; defer out.deinit(gpa); - try out.ensureUnusedCapacity(gpa, default_window_len); var in: std.io.BufferedReader = undefined; in.initFixed(@constCast(compressed)); var zstd_stream: Decompress = .init(&in, .{}); - try std.testing.expectError(error.ReadFailed, zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited)); + try std.testing.expectError( + error.ReadFailed, + zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited, default_window_len), + ); try std.testing.expectError(err, zstd_stream.err orelse {}); } -test "decompression" { +test Decompress { const uncompressed = @embedFile("testdata/rfc8478.txt"); const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3"); const compressed19 = @embedFile("testdata/rfc8478.txt.zst.19"); diff --git a/lib/std/compress/zstd/Decompress.zig b/lib/std/compress/zstd/Decompress.zig index a333405890..16a5db4bfc 100644 --- a/lib/std/compress/zstd/Decompress.zig +++ b/lib/std/compress/zstd/Decompress.zig @@ -149,7 +149,8 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: const header_bytes = try in.takeArray(3); const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*); const block_size = block_header.size; - if (state.frame.block_size_max < block_size) return error.BlockOversize; + const frame_block_size_max = state.frame.block_size_max; + if (frame_block_size_max < block_size) return error.BlockOversize; if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize; var bytes_written: usize = 0; switch (block_header.type) { @@ -185,17 +186,18 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: if (sequences_header.sequence_count > 0) { try decode.readInitialFseState(&bit_stream); - var sequence_size_limit = state.frame.block_size_max; - for (0..sequences_header.sequence_count) |i| { - const decompressed_size = try decode.decodeSequence( - bw, - &bit_stream, - sequence_size_limit, - i == sequences_header.sequence_count - 1, - ); - sequence_size_limit -= decompressed_size; - bytes_written += decompressed_size; + // Ensures the following calls to `decodeSequence` will not flush. + if (frame_block_size_max > bw.buffer.len) return error.OutputBufferUndersize; + const dest = (try bw.writableSliceGreedy(frame_block_size_max))[0..frame_block_size_max]; + for (0..sequences_header.sequence_count - 1) |_| { + bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream); + try decode.updateState(.literal, &bit_stream); + try decode.updateState(.match, &bit_stream); + try decode.updateState(.offset, &bit_stream); } + bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream); + if (bytes_written > dest.len) return error.MalformedSequence; + bw.advance(bytes_written); } if (!bit_stream.isEmpty()) { @@ -206,6 +208,7 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: if (decode.literal_written_count < literals.header.regenerated_size) { const len = literals.header.regenerated_size - decode.literal_written_count; try decode.decodeLiterals(bw, len); + decode.literal_written_count += len; bytes_written += len; } @@ -216,8 +219,7 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: .raw, .rle => {}, } - if (bytes_written > state.frame.block_size_max) return error.BlockOversize; - if (remaining.nonzero()) return error.MalformedCompressedBlock; + if (bytes_written > frame_block_size_max) return error.BlockOversize; state.decompressed_size += bytes_written; if (state.frame.content_size) |size| { @@ -649,63 +651,35 @@ pub const Frame = struct { }; } - const DecodeSequenceError = error{ - /// The decompressed sequence would be longer than - /// `sequence_size_limit` or the sequence's offset is too large - MalformedSequence, - /// The decoder state's literal streams do not contain enough - /// literals for the sequence (this may mean the literal stream or the - /// sequence is malformed). - UnexpectedEndOfLiteralStream, - /// The FSE sequence bitstream is malformed - InvalidBitStream, - /// `bit_reader` does not contain enough bits - EndOfStream, - /// The `BufferedWriter` storage capacity is not large enough to - /// accept this stream. - OutputBufferUndersize, - WriteFailed, - MalformedLiteralsLength, - MalformedFseBits, - MissingStartBit, - HuffmanTreeIncomplete, - }; - /// Decode one sequence from `bit_reader` into `dest`. Updates FSE states /// if `last_sequence` is `false`. Assumes `prepare` called for the block /// before attempting to decode sequences. fn decodeSequence( - self: *Decode, - dest: *BufferedWriter, + decode: *Decode, + dest: []u8, + write_pos: usize, bit_reader: *ReverseBitReader, - sequence_size_limit: usize, - last_sequence: bool, - ) DecodeSequenceError!usize { - const sequence = try self.nextSequence(bit_reader); + ) !usize { + const sequence = try decode.nextSequence(bit_reader); const literal_length: usize = sequence.literal_length; const match_length: usize = sequence.match_length; const sequence_length = literal_length + match_length; - if (sequence_length > sequence_size_limit) return error.MalformedSequence; - if (sequence_length > dest.buffer.len) return error.OutputBufferUndersize; - if (sequence.offset > literal_length) return error.MalformedSequence; - // Ensures the following call to `decodeLiterals` will not cause a - // flush and therefore be at the beginning of `out`. - const out = try dest.writableSlice(sequence_length); - const expected_end = dest.end + out.len; - try decodeLiterals(self, dest, literal_length); - @memmove( - out[literal_length..][0..match_length], - out[literal_length - sequence.offset ..][0..match_length], - ); - dest.advance(match_length); - assert(dest.end == expected_end); + const copy_start = std.math.sub(usize, write_pos + sequence.literal_length, sequence.offset) catch + return error.MalformedSequence; - if (!last_sequence) { - try self.updateState(.literal, bit_reader); - try self.updateState(.match, bit_reader); - try self.updateState(.offset, bit_reader); - } + if (decode.literal_written_count + literal_length > decode.literal_header.regenerated_size) + return error.MalformedLiteralsLength; + var sub_bw: BufferedWriter = undefined; + sub_bw.initFixed(dest[write_pos..]); + try decodeLiterals(decode, &sub_bw, literal_length); + decode.literal_written_count += literal_length; + // This is not a @memmove; it intentionally repeats patterns + // caused by iterating one byte at a time. + for ( + dest[write_pos + literal_length ..][0..match_length], + dest[copy_start..][0..match_length], + ) |*d, s| d.* = s; return sequence_length; } @@ -744,31 +718,14 @@ pub const Frame = struct { }; } - const DecodeLiteralsError = error{ - /// The number of literal bytes decoded by `self` plus `len` is greater - /// than the regenerated size of `literals` - MalformedLiteralsLength, - /// Problems decoding Huffman compressed literals - UnexpectedEndOfLiteralStream, - OutputBufferUndersize, - WriteFailed, - MissingStartBit, - HuffmanTreeIncomplete, - }; - /// Decode `len` bytes of literals into `dest`. - pub fn decodeLiterals(self: *Decode, dest: *BufferedWriter, len: usize) DecodeLiteralsError!void { - if (self.literal_written_count + len > self.literal_header.regenerated_size) - return error.MalformedLiteralsLength; - + fn decodeLiterals(self: *Decode, dest: *BufferedWriter, len: usize) !void { switch (self.literal_header.block_type) { .raw => { try dest.writeAll(self.literal_streams.one[self.literal_written_count..][0..len]); - self.literal_written_count += len; }, .rle => { try dest.splatByteAll(self.literal_streams.one[0], len); - self.literal_written_count += len; }, .compressed, .treeless => { if (len > dest.buffer.len) return error.OutputBufferUndersize; @@ -810,7 +767,6 @@ pub const Frame = struct { } } } - self.literal_written_count += len; }, } }