diff --git a/lib/std/compress/zstd.zig b/lib/std/compress/zstd.zig index cc4f2d38ef..0352a0e1f4 100644 --- a/lib/std/compress/zstd.zig +++ b/lib/std/compress/zstd.zig @@ -1,12 +1,11 @@ const std = @import("../std.zig"); const assert = std.debug.assert; +pub const Decompress = @import("zstd/Decompress.zig"); + /// Recommended amount by the standard. Lower than this may result in inability /// to decompress common streams. pub const default_window_len = 8 * 1024 * 1024; - -pub const Decompress = @import("zstd/Decompress.zig"); - pub const block_size_max = 1 << 17; pub const literals_length_default_distribution = [36]i16{ diff --git a/lib/std/compress/zstd/Decompress.zig b/lib/std/compress/zstd/Decompress.zig index 0c192331f1..d9952af657 100644 --- a/lib/std/compress/zstd/Decompress.zig +++ b/lib/std/compress/zstd/Decompress.zig @@ -10,6 +10,7 @@ input: *Reader, reader: Reader, state: State, verify_checksum: bool, +window_len: u32, err: ?Error = null, const State = union(enum) { @@ -29,6 +30,8 @@ pub const Options = struct { /// Verifying checksums is not implemented yet and will cause a panic if /// you set this to true. verify_checksum: bool = false, + /// Affects the minimum capacity of the provided buffer. + window_len: u32 = zstd.default_window_len, }; pub const Error = error{ @@ -65,11 +68,14 @@ pub const Error = error{ WindowSizeUnknown, }; +/// If the buffer that is written to is not big enough, some streams will fail with +/// `error.OutputBufferUndersize`. A safe value is `zstd.default_window_len * 2`. 
pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress { return .{ .input = input, .state = .new_frame, .verify_checksum = options.verify_checksum, + .window_len = options.window_len, .reader = .{ .vtable = &.{ .stream = stream }, .buffer = buffer, @@ -143,6 +149,7 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void { fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) !usize { const in = d.input; + const window_len = d.window_len; const header_bytes = try in.takeArray(3); const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*); @@ -153,12 +160,12 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) var bytes_written: usize = 0; switch (block_header.type) { .raw => { - try in.streamExact(w, block_size); + try in.streamExactPreserve(w, window_len, block_size); bytes_written = block_size; }, .rle => { const byte = try in.takeByte(); - try w.splatByteAll(byte, block_size); + try w.splatBytePreserve(window_len, byte, block_size); bytes_written = block_size; }, .compressed => { @@ -167,7 +174,7 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined; var literals_buffer: [zstd.block_size_max]u8 = undefined; var sequence_buffer: [zstd.block_size_max]u8 = undefined; - var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer); + var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer, window_len); var remaining: Limit = .limited(block_size); const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer); const sequences_header = try SequencesSection.Header.decode(in, &remaining); @@ -185,15 +192,16 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) try decode.readInitialFseState(&bit_stream); // Ensures the following 
calls to `decodeSequence` will not flush. - if (frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize; - const dest = (try w.writableSliceGreedy(frame_block_size_max))[0..frame_block_size_max]; + if (window_len + frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize; + const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max]; + const write_pos = dest.ptr - w.buffer.ptr; for (0..sequences_header.sequence_count - 1) |_| { - bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream); + bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream); try decode.updateState(.literal, &bit_stream); try decode.updateState(.match, &bit_stream); try decode.updateState(.offset, &bit_stream); } - bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream); + bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream); if (bytes_written > dest.len) return error.MalformedSequence; w.advance(bytes_written); } @@ -363,6 +371,7 @@ pub const Frame = struct { }; pub const Decode = struct { + window_len: u32, repeat_offsets: [3]u32, offset: StateData(8), @@ -397,8 +406,10 @@ pub const Frame = struct { literal_fse_buffer: []Table.Fse, match_fse_buffer: []Table.Fse, offset_fse_buffer: []Table.Fse, + window_len: u32, ) Decode { return .{ + .window_len = window_len, .repeat_offsets = .{ zstd.start_repeated_offset_1, zstd.start_repeated_offset_2, @@ -698,19 +709,19 @@ pub const Frame = struct { }; } - /// Decode `len` bytes of literals into `dest`. - fn decodeLiterals(self: *Decode, dest: *Writer, len: usize) !void { - switch (self.literal_header.block_type) { + /// Decode `len` bytes of literals into `w`. 
+ fn decodeLiterals(d: *Decode, w: *Writer, len: usize) !void { + switch (d.literal_header.block_type) { .raw => { - try dest.writeAll(self.literal_streams.one[self.literal_written_count..][0..len]); + try w.writeAll(d.literal_streams.one[d.literal_written_count..][0..len]); }, .rle => { - try dest.splatByteAll(self.literal_streams.one[0], len); + try w.splatByteAll(d.literal_streams.one[0], len); }, .compressed, .treeless => { - if (len > dest.buffer.len) return error.OutputBufferUndersize; - const buf = try dest.writableSlice(len); - const huffman_tree = self.huffman_tree.?; + if (len > w.buffer.len) return error.OutputBufferUndersize; + const buf = try w.writableSlice(len); + const huffman_tree = d.huffman_tree.?; const max_bit_count = huffman_tree.max_bit_count; const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, @@ -722,7 +733,7 @@ pub const Frame = struct { for (buf) |*out| { var prefix: u16 = 0; while (true) { - const new_bits = try self.readLiteralsBits(bit_count_to_read); + const new_bits = try d.readLiteralsBits(bit_count_to_read); prefix <<= bit_count_to_read; prefix |= new_bits; bits_read += bit_count_to_read;