std.compress.zstandard: add error condition to ring buffer decoding

Previously `executeSequenceRingBuffer()` did not verify the offset
against the number of bytes already decoded, so before the window was
filled it would happily copy garbage bytes rather than return an error.

To fix this, a new `written_count` field is added to the decode state
that tracks the total number of bytes decoded.
This commit is contained in:
dweiller 2023-02-12 22:04:07 +11:00
parent 5a31fc2014
commit a53cf299a6

View File

@ -45,6 +45,7 @@ pub const DecodeState = struct {
huffman_tree: ?LiteralsSection.HuffmanTree,
literal_written_count: usize,
written_count: usize = 0,
fn StateData(comptime max_accuracy_log: comptime_int) type {
return struct {
@ -84,6 +85,8 @@ pub const DecodeState = struct {
.literal_stream_reader = undefined,
.literal_stream_index = undefined,
.huffman_tree = null,
.written_count = 0,
};
}
@ -296,6 +299,7 @@ pub const DecodeState = struct {
// NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
// to allow repeats
std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
self.written_count += sequence.match_length;
}
fn executeSequenceRingBuffer(
@ -303,7 +307,8 @@ pub const DecodeState = struct {
dest: *RingBuffer,
sequence: Sequence,
) (error{MalformedSequence} || DecodeLiteralsError)!void {
if (sequence.offset > dest.data.len) return error.MalformedSequence;
if (sequence.offset > @min(dest.data.len, self.written_count + sequence.literal_length))
return error.MalformedSequence;
try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
const copy_start = dest.write_index + dest.data.len - sequence.offset;
@ -311,6 +316,7 @@ pub const DecodeState = struct {
// TODO: would std.mem.copy and figuring out dest slice be better/faster?
for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
self.written_count += sequence.match_length;
}
const DecodeSequenceError = error{
@ -444,6 +450,7 @@ pub const DecodeState = struct {
const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
std.mem.copy(u8, dest, literal_data);
self.literal_written_count += len;
self.written_count += len;
},
.rle => {
var i: usize = 0;
@ -451,6 +458,7 @@ pub const DecodeState = struct {
dest[i] = self.literal_streams.one[0];
}
self.literal_written_count += len;
self.written_count += len;
},
.compressed, .treeless => {
// const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
@ -497,6 +505,7 @@ pub const DecodeState = struct {
}
}
self.literal_written_count += len;
self.written_count += len;
},
}
}
@ -516,6 +525,7 @@ pub const DecodeState = struct {
const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
dest.writeSliceAssumeCapacity(literal_data);
self.literal_written_count += len;
self.written_count += len;
},
.rle => {
var i: usize = 0;
@ -523,6 +533,7 @@ pub const DecodeState = struct {
dest.writeAssumeCapacity(self.literal_streams.one[0]);
}
self.literal_written_count += len;
self.written_count += len;
},
.compressed, .treeless => {
// const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
@ -565,6 +576,7 @@ pub const DecodeState = struct {
}
}
self.literal_written_count += len;
self.written_count += len;
},
}
}
@ -612,6 +624,7 @@ pub fn decodeBlock(
const data = src[0..block_size];
std.mem.copy(u8, dest[written_count..], data);
consumed_count.* += block_size;
decode_state.written_count += block_size;
return block_size;
},
.rle => {
@ -622,6 +635,7 @@ pub fn decodeBlock(
dest[write_pos] = src[0];
}
consumed_count.* += 1;
decode_state.written_count += block_size;
return block_size;
},
.compressed => {
@ -712,6 +726,7 @@ pub fn decodeBlockRingBuffer(
const data = src[0..block_size];
dest.writeSliceAssumeCapacity(data);
consumed_count.* += block_size;
decode_state.written_count += block_size;
return block_size;
},
.rle => {
@ -721,6 +736,7 @@ pub fn decodeBlockRingBuffer(
dest.writeAssumeCapacity(src[0]);
}
consumed_count.* += 1;
decode_state.written_count += block_size;
return block_size;
},
.compressed => {
@ -814,6 +830,7 @@ pub fn decodeBlockReader(
try source.readNoEof(slice.first);
try source.readNoEof(slice.second);
dest.write_index = dest.mask2(dest.write_index + block_size);
decode_state.written_count += block_size;
},
.rle => {
const byte = try source.readByte();
@ -821,6 +838,7 @@ pub fn decodeBlockReader(
while (i < block_size) : (i += 1) {
dest.writeAssumeCapacity(byte);
}
decode_state.written_count += block_size;
},
.compressed => {
const literals = try decodeLiteralsSection(block_reader, literals_buffer);