std.compress.zstd: tests passing

Andrew Kelley 2025-05-01 18:23:21 -07:00
parent 990031444d
commit 0d29c78af5
2 changed files with 41 additions and 84 deletions

@@ -81,12 +81,11 @@ pub const table_size_max = struct {
fn testDecompress(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 {
var out: std.ArrayListUnmanaged(u8) = .empty;
defer out.deinit(gpa);
try out.ensureUnusedCapacity(gpa, default_window_len);
var in: std.io.BufferedReader = undefined;
in.initFixed(@constCast(compressed));
var zstd_stream: Decompress = .init(&in, .{});
try zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited);
try zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited, default_window_len);
return out.toOwnedSlice(gpa);
}
@@ -103,16 +102,18 @@ fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void {
var out: std.ArrayListUnmanaged(u8) = .empty;
defer out.deinit(gpa);
try out.ensureUnusedCapacity(gpa, default_window_len);
var in: std.io.BufferedReader = undefined;
in.initFixed(@constCast(compressed));
var zstd_stream: Decompress = .init(&in, .{});
try std.testing.expectError(error.ReadFailed, zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited));
try std.testing.expectError(
error.ReadFailed,
zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited, default_window_len),
);
try std.testing.expectError(err, zstd_stream.err orelse {});
}
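These helpers exercise the error-reporting split in the updated reader API: the generic reader can only surface error.ReadFailed, while the concrete zstd failure is recorded on the stream's err field. The following sketch (not part of the patch) shows how a non-test caller living alongside these tests might recover the underlying cause; it reuses only names that appear in this diff (Decompress, initFixed, readRemainingArrayList, default_window_len), and the surrounding std.io API was in flux at this point, so details may differ in other versions.

fn decompressOrCause(gpa: std.mem.Allocator, compressed: []const u8) anyerror![]u8 {
    var out: std.ArrayListUnmanaged(u8) = .empty;
    defer out.deinit(gpa);
    try out.ensureUnusedCapacity(gpa, default_window_len);
    var in: std.io.BufferedReader = undefined;
    in.initFixed(@constCast(compressed));
    var zstd_stream: Decompress = .init(&in, .{});
    zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited, default_window_len) catch |e| {
        // The generic reader reports only error.ReadFailed; the concrete
        // zstd error (e.g. error.BlockOversize) is stored on zstd_stream.err.
        if (e == error.ReadFailed) {
            if (zstd_stream.err) |cause| return cause;
        }
        return e;
    };
    return out.toOwnedSlice(gpa);
}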
test "decompression" {
test Decompress {
const uncompressed = @embedFile("testdata/rfc8478.txt");
const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3");
const compressed19 = @embedFile("testdata/rfc8478.txt.zst.19");

@@ -149,7 +149,8 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
const header_bytes = try in.takeArray(3);
const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*);
const block_size = block_header.size;
if (state.frame.block_size_max < block_size) return error.BlockOversize;
const frame_block_size_max = state.frame.block_size_max;
if (frame_block_size_max < block_size) return error.BlockOversize;
if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize;
var bytes_written: usize = 0;
switch (block_header.type) {
@@ -185,17 +186,18 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
if (sequences_header.sequence_count > 0) {
try decode.readInitialFseState(&bit_stream);
var sequence_size_limit = state.frame.block_size_max;
for (0..sequences_header.sequence_count) |i| {
const decompressed_size = try decode.decodeSequence(
bw,
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
);
sequence_size_limit -= decompressed_size;
bytes_written += decompressed_size;
// Ensures the following calls to `decodeSequence` will not flush.
if (frame_block_size_max > bw.buffer.len) return error.OutputBufferUndersize;
const dest = (try bw.writableSliceGreedy(frame_block_size_max))[0..frame_block_size_max];
for (0..sequences_header.sequence_count - 1) |_| {
bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
try decode.updateState(.literal, &bit_stream);
try decode.updateState(.match, &bit_stream);
try decode.updateState(.offset, &bit_stream);
}
bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
if (bytes_written > dest.len) return error.MalformedSequence;
bw.advance(bytes_written);
}
if (!bit_stream.isEmpty()) {
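The "will not flush" comment above is the reason for this restructure: every sequence appends its literals into the block buffer and then copies a match that may reach back into output produced by earlier sequences of the same block, so the full block-sized destination has to stay writable until the block is done, and FSE state updates are applied explicitly for every sequence except the last. Below is a standalone sketch (not part of the patch) of the back-reference part; appendSequence is a hypothetical stand-in for decodeSequence.

const std = @import("std");

// Hypothetical stand-in for decodeSequence: append the sequence's literals,
// then copy its match from earlier output in the same block buffer.
fn appendSequence(
    dest: []u8,
    write_pos: usize,
    literals: []const u8,
    offset: usize,
    match_length: usize,
) !usize {
    @memcpy(dest[write_pos..][0..literals.len], literals);
    // Underflow here would mean the match reaches before the start of the
    // block buffer (error.MalformedSequence in the real decoder).
    const copy_start = try std.math.sub(usize, write_pos + literals.len, offset);
    for (
        dest[write_pos + literals.len ..][0..match_length],
        dest[copy_start..][0..match_length],
    ) |*d, s| d.* = s;
    return literals.len + match_length;
}

test "a later sequence may reference earlier output of the same block" {
    var block: [16]u8 = undefined;
    var bytes_written: usize = 0;
    // Literals only, no match.
    bytes_written += try appendSequence(&block, bytes_written, "zig ", 1, 0);
    // Offset 8 reaches back into bytes produced by the previous sequence,
    // which only works while the whole block is still writable in memory.
    bytes_written += try appendSequence(&block, bytes_written, "std ", 8, 4);
    try std.testing.expectEqualStrings("zig std zig ", block[0..bytes_written]);
}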
@@ -206,6 +208,7 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
if (decode.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode.literal_written_count;
try decode.decodeLiterals(bw, len);
decode.literal_written_count += len;
bytes_written += len;
}
@@ -216,8 +219,7 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
.raw, .rle => {},
}
if (bytes_written > state.frame.block_size_max) return error.BlockOversize;
if (remaining.nonzero()) return error.MalformedCompressedBlock;
if (bytes_written > frame_block_size_max) return error.BlockOversize;
state.decompressed_size += bytes_written;
if (state.frame.content_size) |size| {
@@ -649,63 +651,35 @@ pub const Frame = struct {
};
}
const DecodeSequenceError = error{
/// The decompressed sequence would be longer than
/// `sequence_size_limit` or the sequence's offset is too large
MalformedSequence,
/// The decoder state's literal streams do not contain enough
/// literals for the sequence (this may mean the literal stream or the
/// sequence is malformed).
UnexpectedEndOfLiteralStream,
/// The FSE sequence bitstream is malformed
InvalidBitStream,
/// `bit_reader` does not contain enough bits
EndOfStream,
/// The `BufferedWriter` storage capacity is not large enough to
/// accept this stream.
OutputBufferUndersize,
WriteFailed,
MalformedLiteralsLength,
MalformedFseBits,
MissingStartBit,
HuffmanTreeIncomplete,
};
/// Decode one sequence from `bit_reader` into `dest`. Updates FSE states
/// if `last_sequence` is `false`. Assumes `prepare` called for the block
/// before attempting to decode sequences.
fn decodeSequence(
self: *Decode,
dest: *BufferedWriter,
decode: *Decode,
dest: []u8,
write_pos: usize,
bit_reader: *ReverseBitReader,
sequence_size_limit: usize,
last_sequence: bool,
) DecodeSequenceError!usize {
const sequence = try self.nextSequence(bit_reader);
) !usize {
const sequence = try decode.nextSequence(bit_reader);
const literal_length: usize = sequence.literal_length;
const match_length: usize = sequence.match_length;
const sequence_length = literal_length + match_length;
if (sequence_length > sequence_size_limit) return error.MalformedSequence;
if (sequence_length > dest.buffer.len) return error.OutputBufferUndersize;
if (sequence.offset > literal_length) return error.MalformedSequence;
// Ensures the following call to `decodeLiterals` will not cause a
// flush and therefore be at the beginning of `out`.
const out = try dest.writableSlice(sequence_length);
const expected_end = dest.end + out.len;
try decodeLiterals(self, dest, literal_length);
@memmove(
out[literal_length..][0..match_length],
out[literal_length - sequence.offset ..][0..match_length],
);
dest.advance(match_length);
assert(dest.end == expected_end);
const copy_start = std.math.sub(usize, write_pos + sequence.literal_length, sequence.offset) catch
return error.MalformedSequence;
if (!last_sequence) {
try self.updateState(.literal, bit_reader);
try self.updateState(.match, bit_reader);
try self.updateState(.offset, bit_reader);
}
if (decode.literal_written_count + literal_length > decode.literal_header.regenerated_size)
return error.MalformedLiteralsLength;
var sub_bw: BufferedWriter = undefined;
sub_bw.initFixed(dest[write_pos..]);
try decodeLiterals(decode, &sub_bw, literal_length);
decode.literal_written_count += literal_length;
// This is not a @memmove; it intentionally repeats patterns
// caused by iterating one byte at a time.
for (
dest[write_pos + literal_length ..][0..match_length],
dest[copy_start..][0..match_length],
) |*d, s| d.* = s;
return sequence_length;
}
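A standalone illustration (not part of the patch) of why the match copy is a forward byte-by-byte loop rather than @memmove: when the offset is smaller than the match length, each iteration re-reads bytes the loop itself just wrote, which is how zstd expresses run-length-style repetition. The copy_start underflow caught by std.math.sub above is the case where the offset would reach before the start of the block buffer.

const std = @import("std");

test "byte-wise match copy repeats short patterns" {
    var dest: [8]u8 = undefined;
    @memcpy(dest[0..2], "ab"); // literals already decoded into the block
    const write_pos: usize = 2;
    const offset: usize = 2;
    const match_length: usize = 6; // longer than the offset
    const copy_start = write_pos - offset; // underflow here => malformed sequence
    // Each iteration may re-read a byte this same loop just wrote, so offset 2
    // with match length 6 extends "ab" to "abababab". A @memmove over the
    // overlapping region would not repeat the pattern this way.
    for (
        dest[write_pos..][0..match_length],
        dest[copy_start..][0..match_length],
    ) |*d, s| d.* = s;
    try std.testing.expectEqualStrings("abababab", &dest);
}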
@@ -744,31 +718,14 @@ pub const Frame = struct {
};
}
const DecodeLiteralsError = error{
/// The number of literal bytes decoded by `self` plus `len` is greater
/// than the regenerated size of `literals`
MalformedLiteralsLength,
/// Problems decoding Huffman compressed literals
UnexpectedEndOfLiteralStream,
OutputBufferUndersize,
WriteFailed,
MissingStartBit,
HuffmanTreeIncomplete,
};
/// Decode `len` bytes of literals into `dest`.
pub fn decodeLiterals(self: *Decode, dest: *BufferedWriter, len: usize) DecodeLiteralsError!void {
if (self.literal_written_count + len > self.literal_header.regenerated_size)
return error.MalformedLiteralsLength;
fn decodeLiterals(self: *Decode, dest: *BufferedWriter, len: usize) !void {
switch (self.literal_header.block_type) {
.raw => {
try dest.writeAll(self.literal_streams.one[self.literal_written_count..][0..len]);
self.literal_written_count += len;
},
.rle => {
try dest.splatByteAll(self.literal_streams.one[0], len);
self.literal_written_count += len;
},
.compressed, .treeless => {
if (len > dest.buffer.len) return error.OutputBufferUndersize;
@@ -810,7 +767,6 @@ pub const Frame = struct {
}
}
}
self.literal_written_count += len;
},
}
}
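With this change the .raw and .rle branches no longer advance literal_written_count themselves; the callers (decodeSequence and the literals tail in readInFrame) do that bookkeeping once per call. For reference, a standalone sketch (not part of the patch) of what those two simple branches produce, with @memcpy and @memset standing in for the writer calls used above:

const std = @import("std");

test "raw and rle literal sections" {
    var out: [5]u8 = undefined;
    const stream = "hello world"; // the decoded literal stream

    // .raw: the next len bytes are copied verbatim from the single literal
    // stream, starting at the caller-tracked literal_written_count.
    const literal_written_count: usize = 6;
    @memcpy(&out, stream[literal_written_count..][0..out.len]);
    try std.testing.expectEqualStrings("world", &out);

    // .rle: the stream holds one byte which is repeated len times; this is
    // the role splatByteAll plays in the code above.
    @memset(&out, stream[0]);
    try std.testing.expectEqualStrings("hhhhh", &out);
}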