std.compress.zstd: fix raw/rle blocks not doing frame accounting

This commit is contained in:
Andrew Kelley 2025-05-01 15:38:05 -07:00
parent 439117d0d7
commit 2ff5ea2231
3 changed files with 61 additions and 63 deletions

View File

@ -123,19 +123,20 @@ test "decompression" {
try testExpectDecompress(uncompressed, compressed19);
}
// Regression test: a frame whose only block is a raw block of size zero.
// The expected output passed to testExpectDecompress is the empty string.
test "zero sized raw block" {
    const input_raw =
        "\x28\xb5\x2f\xfd" ++ // zstandard frame magic number
        "\x20\x00" ++ // frame header: only single_segment_flag set, frame_content_size zero
        "\x01\x00\x00"; // block header with: last_block set, block_type raw, block_size zero
    try testExpectDecompress("", input_raw);
}
// Regression test: a frame whose only block is an RLE block of size zero.
// The single block_content byte is still present in the input even though the
// block size is zero; the expected output is the empty string.
test "zero sized rle block" {
    const input_rle =
        "\x28\xb5\x2f\xfd" ++ // zstandard frame magic number
        "\x20\x00" ++ // frame header: only single_segment_flag set, frame_content_size zero
        "\x03\x00\x00" ++ // block header with: last_block set, block_type rle, block_size zero
        "\xaa"; // block_content
    // Only the RLE input belongs here; the raw-block input is declared and
    // checked in its own test and is not in scope in this one.
    try testExpectDecompress("", input_rle);
}

View File

@ -147,88 +147,85 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void {
fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: *State.InFrame) !usize {
const in = d.input;
var literal_fse_buffer: [zstd.table_size_max.literal]Table.Fse = undefined;
var match_fse_buffer: [zstd.table_size_max.match]Table.Fse = undefined;
var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined;
var literals_buffer: [zstd.block_size_max]u8 = undefined;
var sequence_buffer: [zstd.block_size_max]u8 = undefined;
var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer);
const header_bytes = try in.takeArray(3);
const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*);
const block_size = block_header.size;
if (state.frame.block_size_max < block_size) return error.BlockOversize;
if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize;
var bytes_written: usize = 0;
switch (block_header.type) {
.raw => {
try in.readAll(bw, .limited(block_size));
return block_size;
bytes_written = block_size;
},
.rle => {
const byte = try in.takeByte();
try bw.splatByteAll(byte, block_size);
return block_size;
bytes_written = block_size;
},
.compressed => {},
.reserved => return error.ReservedBlock,
}
.compressed => {
var literal_fse_buffer: [zstd.table_size_max.literal]Table.Fse = undefined;
var match_fse_buffer: [zstd.table_size_max.match]Table.Fse = undefined;
var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined;
var literals_buffer: [zstd.block_size_max]u8 = undefined;
var sequence_buffer: [zstd.block_size_max]u8 = undefined;
var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer);
var remaining: Reader.Limit = .limited(block_size);
const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer);
const sequences_header = try SequencesSection.Header.decode(in, &remaining);
var remaining: Reader.Limit = .limited(block_size);
const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer);
const sequences_header = try SequencesSection.Header.decode(in, &remaining);
try decode.prepare(in, &remaining, literals, sequences_header);
try decode.prepare(in, &remaining, literals, sequences_header);
{
if (sequence_buffer.len < @intFromEnum(remaining))
return error.SequenceBufferUndersize;
const seq_slice = remaining.slice(&sequence_buffer);
try in.readSlice(seq_slice);
var bit_stream = try ReverseBitReader.init(seq_slice);
var bytes_written: usize = 0;
{
if (sequence_buffer.len < @intFromEnum(remaining))
return error.SequenceBufferUndersize;
const seq_slice = remaining.slice(&sequence_buffer);
try in.readSlice(seq_slice);
var bit_stream = try ReverseBitReader.init(seq_slice);
if (sequences_header.sequence_count > 0) {
try decode.readInitialFseState(&bit_stream);
if (sequences_header.sequence_count > 0) {
try decode.readInitialFseState(&bit_stream);
var sequence_size_limit = state.frame.block_size_max;
for (0..sequences_header.sequence_count) |i| {
const decompressed_size = try decode.decodeSequence(
bw,
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
);
sequence_size_limit -= decompressed_size;
bytes_written += decompressed_size;
}
}
var sequence_size_limit = state.frame.block_size_max;
for (0..sequences_header.sequence_count) |i| {
const decompressed_size = try decode.decodeSequence(
bw,
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
);
sequence_size_limit -= decompressed_size;
bytes_written += decompressed_size;
if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock;
}
}
}
if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock;
}
}
if (decode.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode.literal_written_count;
try decode.decodeLiterals(bw, len);
bytes_written += len;
}
if (decode.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode.literal_written_count;
try decode.decodeLiterals(bw, len);
bytes_written += len;
}
switch (decode.literal_header.block_type) {
.treeless, .compressed => {
if (!decode.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
},
.raw, .rle => {},
}
switch (decode.literal_header.block_type) {
.treeless, .compressed => {
if (!decode.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
if (bytes_written > state.frame.block_size_max) return error.BlockOversize;
if (remaining.nonzero()) return error.MalformedCompressedBlock;
state.decompressed_size += bytes_written;
if (state.frame.content_size) |size| {
if (state.decompressed_size > size) return error.MalformedFrame;
}
},
.raw, .rle => {},
}
if (bytes_written > state.frame.block_size_max) return error.BlockOversize;
if (remaining.nonzero()) return error.MalformedCompressedBlock;
state.decompressed_size += bytes_written;
if (state.frame.content_size) |size| {
if (state.decompressed_size > size) return error.MalformedFrame;
.reserved => return error.ReservedBlock,
}
if (state.frame.hasher_opt) |*hasher| {

View File

@ -730,7 +730,7 @@ pub fn fill(br: *BufferedReader, n: usize) Reader.Error!void {
}
if (seek > 0) {
const remainder = buffer[seek..];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
@memmove(buffer[0..remainder.len], remainder);
br.end = remainder.len;
br.seek = 0;
}