std.compress.zstd: respect the window length

Andrew Kelley 2025-07-24 23:31:00 -07:00
parent 7f1c04423e
commit ee4f5b3f92
2 changed files with 29 additions and 19 deletions

lib/std/compress/zstd.zig

@@ -1,12 +1,11 @@
const std = @import("../std.zig");
const assert = std.debug.assert;
pub const Decompress = @import("zstd/Decompress.zig");
/// Recommended amount by the standard. Lower than this may result in inability
/// to decompress common streams.
pub const default_window_len = 8 * 1024 * 1024;
pub const Decompress = @import("zstd/Decompress.zig");
pub const block_size_max = 1 << 17;
pub const literals_length_default_distribution = [36]i16{
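
A minimal usage sketch (not part of this commit) of how a caller sizes the decompression buffer against this constant. It assumes an in-memory source via `std.Io.Reader.fixed`, and `allocRemaining` is assumed here as just one way to drain the resulting stream:

    const std = @import("std");
    const zstd = std.compress.zstd;

    /// Decompress an entire in-memory zstd stream. The buffer size follows the
    /// `zstd.default_window_len * 2` guidance from the Decompress.init doc comment.
    fn decompressAlloc(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 {
        var in: std.Io.Reader = .fixed(compressed);
        const window_buffer = try gpa.alloc(u8, zstd.default_window_len * 2);
        defer gpa.free(window_buffer);
        var decompress: zstd.Decompress = .init(&in, window_buffer, .{});
        return try decompress.reader.allocRemaining(gpa, .unlimited);
    }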

lib/std/compress/zstd/Decompress.zig

@@ -10,6 +10,7 @@ input: *Reader,
reader: Reader,
state: State,
verify_checksum: bool,
window_len: u32,
err: ?Error = null,
const State = union(enum) {
@@ -29,6 +30,8 @@ pub const Options = struct {
/// Verifying checksums is not implemented yet and will cause a panic if
/// you set this to true.
verify_checksum: bool = false,
/// Affects the minimum capacity of the provided buffer.
window_len: u32 = zstd.default_window_len,
};
pub const Error = error{
@@ -65,11 +68,14 @@ pub const Error = error{
WindowSizeUnknown,
};
/// If the buffer that is written to is not big enough, some streams will fail with
/// `error.OutputBufferUndersize`. A safe value is `zstd.default_window_len * 2`.
pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress {
return .{
.input = input,
.state = .new_frame,
.verify_checksum = options.verify_checksum,
.window_len = options.window_len,
.reader = .{
.vtable = &.{ .stream = stream },
.buffer = buffer,
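
A hedged sketch (also not from this commit) of passing an explicit `window_len` when the producer is known to use a small window; the `window_len + block_size_max` lower bound is a reading of the `OutputBufferUndersize` check in `readInFrame` below, not a documented guarantee:

    const std = @import("std");
    const zstd = std.compress.zstd;

    /// Hypothetical helper: the caller already knows no frame in the stream uses
    /// a window larger than `window_len` bytes.
    fn initWithSmallWindow(in: *std.Io.Reader, buffer: []u8, window_len: u32) zstd.Decompress {
        // The buffer must hold the preserved window plus one block, otherwise
        // decompression can fail later with error.OutputBufferUndersize.
        std.debug.assert(buffer.len >= window_len + zstd.block_size_max);
        return .init(in, buffer, .{ .window_len = window_len });
    }
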
@@ -143,6 +149,7 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void {
fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) !usize {
const in = d.input;
const window_len = d.window_len;
const header_bytes = try in.takeArray(3);
const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*);
@@ -153,12 +160,12 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
var bytes_written: usize = 0;
switch (block_header.type) {
.raw => {
try in.streamExact(w, block_size);
try in.streamExactPreserve(w, window_len, block_size);
bytes_written = block_size;
},
.rle => {
const byte = try in.takeByte();
try w.splatByteAll(byte, block_size);
try w.splatBytePreserve(window_len, byte, block_size);
bytes_written = block_size;
},
.compressed => {
@@ -167,7 +174,7 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined;
var literals_buffer: [zstd.block_size_max]u8 = undefined;
var sequence_buffer: [zstd.block_size_max]u8 = undefined;
var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer);
var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer, window_len);
var remaining: Limit = .limited(block_size);
const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer);
const sequences_header = try SequencesSection.Header.decode(in, &remaining);
@@ -185,15 +192,16 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
try decode.readInitialFseState(&bit_stream);
// Ensures the following calls to `decodeSequence` will not flush.
if (frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize;
const dest = (try w.writableSliceGreedy(frame_block_size_max))[0..frame_block_size_max];
if (window_len + frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize;
const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max];
const write_pos = dest.ptr - w.buffer.ptr;
for (0..sequences_header.sequence_count - 1) |_| {
bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
try decode.updateState(.literal, &bit_stream);
try decode.updateState(.match, &bit_stream);
try decode.updateState(.offset, &bit_stream);
}
bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
if (bytes_written > dest.len) return error.MalformedSequence;
w.advance(bytes_written);
}
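
A small worked check (an addition for clarity, not commit content) of why the `default_window_len * 2` guidance in the init doc comment satisfies the bound above, assuming `frame_block_size_max` never exceeds `zstd.block_size_max`:

    const std = @import("std");
    const zstd = std.compress.zstd;

    comptime {
        // 8 MiB window + 128 KiB maximum block fits comfortably within the 16 MiB
        // (`default_window_len * 2`) buffer that the init doc comment calls safe.
        std.debug.assert(zstd.default_window_len + zstd.block_size_max <= 2 * zstd.default_window_len);
    }
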
@@ -363,6 +371,7 @@ pub const Frame = struct {
};
pub const Decode = struct {
window_len: u32,
repeat_offsets: [3]u32,
offset: StateData(8),
@@ -397,8 +406,10 @@ pub const Frame = struct {
literal_fse_buffer: []Table.Fse,
match_fse_buffer: []Table.Fse,
offset_fse_buffer: []Table.Fse,
window_len: u32,
) Decode {
return .{
.window_len = window_len,
.repeat_offsets = .{
zstd.start_repeated_offset_1,
zstd.start_repeated_offset_2,
@@ -698,19 +709,19 @@ pub const Frame = struct {
};
}
/// Decode `len` bytes of literals into `dest`.
fn decodeLiterals(self: *Decode, dest: *Writer, len: usize) !void {
switch (self.literal_header.block_type) {
/// Decode `len` bytes of literals into `w`.
fn decodeLiterals(d: *Decode, w: *Writer, len: usize) !void {
switch (d.literal_header.block_type) {
.raw => {
try dest.writeAll(self.literal_streams.one[self.literal_written_count..][0..len]);
try w.writeAll(d.literal_streams.one[d.literal_written_count..][0..len]);
},
.rle => {
try dest.splatByteAll(self.literal_streams.one[0], len);
try w.splatByteAll(d.literal_streams.one[0], len);
},
.compressed, .treeless => {
if (len > dest.buffer.len) return error.OutputBufferUndersize;
const buf = try dest.writableSlice(len);
const huffman_tree = self.huffman_tree.?;
if (len > w.buffer.len) return error.OutputBufferUndersize;
const buf = try w.writableSlice(len);
const huffman_tree = d.huffman_tree.?;
const max_bit_count = huffman_tree.max_bit_count;
const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
@@ -722,7 +733,7 @@ pub const Frame = struct {
for (buf) |*out| {
var prefix: u16 = 0;
while (true) {
const new_bits = try self.readLiteralsBits(bit_count_to_read);
const new_bits = try d.readLiteralsBits(bit_count_to_read);
prefix <<= bit_count_to_read;
prefix |= new_bits;
bits_read += bit_count_to_read;