diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index 18fbf289e2..5a197b4edd 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -45,6 +45,7 @@ pub const DecodeState = struct { huffman_tree: ?LiteralsSection.HuffmanTree, literal_written_count: usize, + written_count: usize = 0, fn StateData(comptime max_accuracy_log: comptime_int) type { return struct { @@ -84,6 +85,8 @@ pub const DecodeState = struct { .literal_stream_reader = undefined, .literal_stream_index = undefined, .huffman_tree = null, + + .written_count = 0, }; } @@ -296,6 +299,7 @@ pub const DecodeState = struct { // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr // to allow repeats std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]); + self.written_count += sequence.match_length; } fn executeSequenceRingBuffer( @@ -303,7 +307,8 @@ pub const DecodeState = struct { dest: *RingBuffer, sequence: Sequence, ) (error{MalformedSequence} || DecodeLiteralsError)!void { - if (sequence.offset > dest.data.len) return error.MalformedSequence; + if (sequence.offset > @min(dest.data.len, self.written_count + sequence.literal_length)) + return error.MalformedSequence; try self.decodeLiteralsRingBuffer(dest, sequence.literal_length); const copy_start = dest.write_index + dest.data.len - sequence.offset; @@ -311,6 +316,7 @@ pub const DecodeState = struct { // TODO: would std.mem.copy and figuring out dest slice be better/faster? for (copy_slice.first) |b| dest.writeAssumeCapacity(b); for (copy_slice.second) |b| dest.writeAssumeCapacity(b); + self.written_count += sequence.match_length; } const DecodeSequenceError = error{ @@ -444,6 +450,7 @@ pub const DecodeState = struct { const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; std.mem.copy(u8, dest, literal_data); self.literal_written_count += len; + self.written_count += len; }, .rle => { var i: usize = 0; @@ -451,6 +458,7 @@ pub const DecodeState = struct { dest[i] = self.literal_streams.one[0]; } self.literal_written_count += len; + self.written_count += len; }, .compressed, .treeless => { // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; @@ -497,6 +505,7 @@ pub const DecodeState = struct { } } self.literal_written_count += len; + self.written_count += len; }, } } @@ -516,6 +525,7 @@ pub const DecodeState = struct { const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; dest.writeSliceAssumeCapacity(literal_data); self.literal_written_count += len; + self.written_count += len; }, .rle => { var i: usize = 0; @@ -523,6 +533,7 @@ pub const DecodeState = struct { dest.writeAssumeCapacity(self.literal_streams.one[0]); } self.literal_written_count += len; + self.written_count += len; }, .compressed, .treeless => { // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; @@ -565,6 +576,7 @@ pub const DecodeState = struct { } } self.literal_written_count += len; + self.written_count += len; }, } } @@ -612,6 +624,7 @@ pub fn decodeBlock( const data = src[0..block_size]; std.mem.copy(u8, dest[written_count..], data); consumed_count.* += block_size; + decode_state.written_count += block_size; return block_size; }, .rle => { @@ -622,6 +635,7 @@ pub fn decodeBlock( dest[write_pos] = src[0]; } consumed_count.* += 1; + decode_state.written_count += block_size; return block_size; }, .compressed => { @@ -712,6 +726,7 @@ pub fn decodeBlockRingBuffer( const data = src[0..block_size]; dest.writeSliceAssumeCapacity(data); consumed_count.* += block_size; + decode_state.written_count += block_size; return block_size; }, .rle => { @@ -721,6 +736,7 @@ pub fn decodeBlockRingBuffer( dest.writeAssumeCapacity(src[0]); } consumed_count.* += 1; + decode_state.written_count += block_size; return block_size; }, .compressed => { @@ -814,6 +830,7 @@ pub fn decodeBlockReader( try source.readNoEof(slice.first); try source.readNoEof(slice.second); dest.write_index = dest.mask2(dest.write_index + block_size); + decode_state.written_count += block_size; }, .rle => { const byte = try source.readByte(); @@ -821,6 +838,7 @@ pub fn decodeBlockReader( while (i < block_size) : (i += 1) { dest.writeAssumeCapacity(byte); } + decode_state.written_count += block_size; }, .compressed => { const literals = try decodeLiteralsSection(block_reader, literals_buffer);