std.compress.zstandard: add functions for decoding into a ring buffer

This supports decoding frames that do not declare their content size, as
well as decoding in a streaming fashion.
dweiller 2023-01-22 16:11:47 +11:00
parent 18091723d5
commit 05e63f241e
2 changed files with 272 additions and 3 deletions

View File

@ -0,0 +1,81 @@
//! This ring buffer stores read and write indices while being able to utilise the full
//! backing slice by incrementing the indices modulo twice the slice's length and reducing
//! indices modulo the slice's length on slice access. This means that the bit of information
//! distinguishing whether the buffer is full or empty, which an implementation utilising
//! an extra flag would store separately, is here stored in the difference of the indices:
//! equal indices mean the buffer is empty, while indices differing by exactly the slice's
//! length mean it is full.
const assert = @import("std").debug.assert;

const RingBuffer = @This();

data: []u8,
read_index: usize,
write_index: usize,

/// Returns `index` modulo the length of the backing slice.
pub fn mask(self: RingBuffer, index: usize) usize {
    return index % self.data.len;
}

/// Returns `index` modulo twice the length of the backing slice.
pub fn mask2(self: RingBuffer, index: usize) usize {
    return index % (2 * self.data.len);
}

/// Write `byte` into the ring buffer. Returns `error.Full` if the ring buffer is full.
pub fn write(self: *RingBuffer, byte: u8) !void {
    if (self.isFull()) return error.Full;
    self.writeAssumeCapacity(byte);
}

/// Write `byte` into the ring buffer. If the ring buffer is full, the oldest byte is
/// overwritten.
pub fn writeAssumeCapacity(self: *RingBuffer, byte: u8) void {
    self.data[self.mask(self.write_index)] = byte;
    self.write_index = self.mask2(self.write_index + 1);
}

/// Write `bytes` into the ring buffer. Returns `error.Full` if the ring buffer does not
/// have enough space, without writing any data.
pub fn writeSlice(self: *RingBuffer, bytes: []const u8) !void {
    if (self.len() + bytes.len > self.data.len) return error.Full;
    self.writeSliceAssumeCapacity(bytes);
}

/// Write `bytes` into the ring buffer, overwriting the oldest data if there is not
/// enough space.
pub fn writeSliceAssumeCapacity(self: *RingBuffer, bytes: []const u8) void {
    for (bytes) |b| self.writeAssumeCapacity(b);
}

/// Consume a byte from the ring buffer and return it, or `null` if the ring buffer is
/// empty.
pub fn read(self: *RingBuffer) ?u8 {
    if (self.isEmpty()) return null;
    const byte = self.data[self.mask(self.read_index)];
    self.read_index = self.mask2(self.read_index + 1);
    return byte;
}

/// Returns `true` if the ring buffer is empty and `false` otherwise.
pub fn isEmpty(self: RingBuffer) bool {
    return self.write_index == self.read_index;
}

/// Returns `true` if the ring buffer is full and `false` otherwise; the buffer is full
/// exactly when the indices differ by `data.len` modulo `2 * data.len`.
pub fn isFull(self: RingBuffer) bool {
    return self.mask2(self.write_index + self.data.len) == self.read_index;
}

/// Returns the number of bytes available for reading.
pub fn len(self: RingBuffer) usize {
    const adjusted_write_index = self.write_index + @boolToInt(self.write_index < self.read_index) * 2 * self.data.len;
    return adjusted_write_index - self.read_index;
}

/// A region of the ring buffer, split into two sections because the region is not
/// contiguous in the backing slice when it wraps around.
const Slice = struct {
    first: []u8,
    second: []u8,
};

/// Returns a `Slice` of `length` bytes starting at `self.mask(start_unmasked)`.
pub fn sliceAt(self: RingBuffer, start_unmasked: usize, length: usize) Slice {
    assert(length <= self.data.len);
    const slice1_start = self.mask(start_unmasked);
    const slice1_end = @min(self.data.len, slice1_start + length);
    const slice1 = self.data[slice1_start..slice1_end];
    const slice2 = self.data[0 .. length - slice1.len];
    return Slice{
        .first = slice1,
        .second = slice2,
    };
}

/// Returns a `Slice` of the last `length` bytes written to the ring buffer.
pub fn sliceLast(self: RingBuffer, length: usize) Slice {
    return self.sliceAt(self.write_index + self.data.len - length, length);
}
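
As a quick illustration of the doubled-index scheme, here is a hypothetical test (not part of this commit) exercising the full/empty distinction: completely filling a buffer of length 4 leaves the indices exactly 4 apart, and draining it brings them back to equality.

const testing = @import("std").testing;

test "RingBuffer distinguishes full from empty" {
    var storage: [4]u8 = undefined;
    var ring = RingBuffer{ .data = &storage, .read_index = 0, .write_index = 0 };
    try testing.expect(ring.isEmpty());

    // Filling the buffer advances write_index to 4 == data.len, so the indices
    // differ by the slice length and the buffer reads as full.
    try ring.writeSlice("abcd");
    try testing.expect(ring.isFull());
    try testing.expectError(error.Full, ring.write('e'));

    // Reading all four bytes advances read_index to 4 as well; equal indices
    // mean the buffer reads as empty again.
    while (ring.read()) |_| {}
    try testing.expect(ring.isEmpty());
}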

View File

@ -6,6 +6,7 @@ const frame = types.frame;
const Literals = types.compressed_block.Literals;
const Sequences = types.compressed_block.Sequences;
const Table = types.compressed_block.Table;
const RingBuffer = @import("RingBuffer.zig");
const readInt = std.mem.readIntLittle;
const readIntSlice = std.mem.readIntSliceLittle;
@ -214,7 +215,7 @@ const DecodeState = struct {
    }

    fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
-       try self.decodeLiteralsInto(dest[write_pos..], literals, sequence.literal_length);
+       try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
        // TODO: should we validate offset against max_window_size?
        assert(sequence.offset <= write_pos + sequence.literal_length);
@ -225,6 +226,15 @@ const DecodeState = struct {
        std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
    }

    fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
        try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
        // TODO: check that ring buffer window is full enough for match copies
        // Adding `dest.data.len` before subtracting the offset keeps the unmasked
        // index from underflowing; `sliceAt` reduces it modulo the buffer length.
        const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
        // TODO: would std.mem.copy and figuring out dest slice be better/faster?
        for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
        for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
    }

    fn decodeSequenceSlice(
        self: *DecodeState,
        dest: []u8,
@ -246,6 +256,31 @@ const DecodeState = struct {
        return sequence.match_length + sequence.literal_length;
    }

    fn decodeSequenceRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        literals: Literals,
        bit_reader: anytype,
        last_sequence: bool,
    ) !usize {
        const sequence = try self.nextSequence(bit_reader);
        try self.executeSequenceRingBuffer(dest, literals, sequence);
        if (std.options.log_level == .debug) {
            const sequence_length = sequence.literal_length + sequence.match_length;
            const written_slice = dest.sliceLast(sequence_length);
            log.debug("sequence decompressed into '{x}{x}'", .{
                std.fmt.fmtSliceHexUpper(written_slice.first),
                std.fmt.fmtSliceHexUpper(written_slice.second),
            });
        }
        if (!last_sequence) {
            try self.updateState(.literal, bit_reader);
            try self.updateState(.match, bit_reader);
            try self.updateState(.offset, bit_reader);
        }
        return sequence.match_length + sequence.literal_length;
    }

    fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
        self.literal_stream_index += 1;
        try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
@ -258,7 +293,7 @@ const DecodeState = struct {
        while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {}
    }

-   fn decodeLiteralsInto(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
+   fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
        switch (literals.header.block_type) {
            .raw => {
@ -327,6 +362,74 @@ const DecodeState = struct {
        }
    }

    fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
        switch (literals.header.block_type) {
            .raw => {
                const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
                dest.writeSliceAssumeCapacity(literal_data);
                self.literal_written_count += len;
            },
            .rle => {
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    dest.writeAssumeCapacity(literals.streams.one[0]);
                }
                self.literal_written_count += len;
            },
            .compressed, .treeless => {
                // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
                const huffman_tree = self.huffman_tree orelse unreachable;
                const max_bit_count = huffman_tree.max_bit_count;
                const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                    max_bit_count,
                );
                var bits_read: u4 = 0;
                var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
                var bit_count_to_read: u4 = starting_bit_count;
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    var prefix: u16 = 0;
                    // Read prefix bits and walk the Huffman tree until a full
                    // symbol code is matched.
                    while (true) {
                        const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
                            switch (err) {
                            error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
                                try self.nextLiteralMultiStream(literals);
                                break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
                            } else {
                                return error.UnexpectedEndOfLiteralStream;
                            },
                        };
                        prefix <<= bit_count_to_read;
                        prefix |= new_bits;
                        bits_read += bit_count_to_read;
                        const result = try huffman_tree.query(huffman_tree_index, prefix);
                        switch (result) {
                            .symbol => |sym| {
                                dest.writeAssumeCapacity(sym);
                                bit_count_to_read = starting_bit_count;
                                bits_read = 0;
                                huffman_tree_index = huffman_tree.symbol_count_minus_one;
                                break;
                            },
                            .index => |index| {
                                huffman_tree_index = index;
                                const bit_count = Literals.HuffmanTree.weightToBitCount(
                                    huffman_tree.nodes[index].weight,
                                    max_bit_count,
                                );
                                bit_count_to_read = bit_count - bits_read;
                            },
                        }
                    }
                }
                self.literal_written_count += len;
            },
        }
    }

    fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
        return switch (@field(self, @tagName(choice)).table) {
            .rle => |value| value,
@ -437,6 +540,14 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
    return block_size;
}

fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
    log.debug("writing raw block - size {d}", .{block_size});
    const data = src[0..block_size];
    dest.writeSliceAssumeCapacity(data);
    consumed_count.* += block_size;
    return block_size;
}

fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
    log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
    var write_pos: usize = 0;
@ -447,6 +558,16 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
    return block_size;
}

fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
    log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
    var write_pos: usize = 0;
    while (write_pos < block_size) : (write_pos += 1) {
        dest.writeAssumeCapacity(src[0]);
    }
    // Only the single repeated byte is consumed from `src`.
    consumed_count.* += 1;
    return block_size;
}

fn prepareDecodeState(
    decode_state: *DecodeState,
    src: []const u8,
@ -545,7 +666,7 @@ pub fn decodeBlock(
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                log.debug("decoding remaining literals", .{});
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
-               try decode_state.decodeLiteralsInto(dest[written_count + bytes_written ..], literals, len);
+               try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
                log.debug("remaining decoded literals at {d}: {}", .{
                    written_count,
                    std.fmt.fmtSliceHexUpper(dest[written_count .. written_count + len]),
@ -562,6 +683,73 @@ pub fn decodeBlock(
    }
}

pub fn decodeBlockRingBuffer(
    dest: *RingBuffer,
    src: []const u8,
    block_header: frame.ZStandard.Block.Header,
    decode_state: *DecodeState,
    consumed_count: *usize,
    block_size_maximum: usize,
) !usize {
    const block_size = block_header.block_size;
    if (block_size_maximum < block_size) return error.BlockSizeOverMaximum;
    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
    switch (block_header.block_type) {
        .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
        .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
        .compressed => {
            var bytes_read: usize = 0;
            const literals = try decodeLiteralsSection(src, &bytes_read);
            const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
            var bytes_written: usize = 0;
            if (sequences_header.sequence_count > 0) {
                const bit_stream_bytes = src[bytes_read..block_size];
                var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
                var bit_stream = reverseBitReader(reverse_byte_reader.reader());
                // Skip the zero padding bits, up to and including the first set
                // bit that marks the start of the bit stream.
                while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
                try decode_state.readInitialState(&bit_stream);
                var i: usize = 0;
                while (i < sequences_header.sequence_count) : (i += 1) {
                    log.debug("decoding sequence {d}", .{i});
                    const decompressed_size = try decode_state.decodeSequenceRingBuffer(
                        dest,
                        literals,
                        &bit_stream,
                        i == sequences_header.sequence_count - 1,
                    );
                    bytes_written += decompressed_size;
                }
                bytes_read += bit_stream_bytes.len;
            }
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                log.debug("decoding remaining literals", .{});
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
                const written_slice = dest.sliceLast(len);
                log.debug("remaining decoded literals at {d}: {}{}", .{
                    bytes_written,
                    std.fmt.fmtSliceHexUpper(written_slice.first),
                    std.fmt.fmtSliceHexUpper(written_slice.second),
                });
                bytes_written += len;
            }
            decode_state.literal_written_count = 0;
            assert(bytes_read == block_header.block_size);
            consumed_count.* += bytes_read;
            return bytes_written;
        },
        .reserved => return error.FrameContainsReservedBlock,
    }
}

pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
    const magic = readInt(u32, src[0..4]);
    assert(isSkippableMagic(magic));
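
The new decodeBlockRingBuffer entry point is meant to be driven block by block, with the ring buffer acting as the decompression window, which is what makes streaming decoding and frames without a declared content size possible. A minimal sketch of such a driver follows; it is not part of this commit, and the 3-byte block-header parsing (decodeBlockHeader, header.last_block) as well as the decode-state and window setup are assumptions for illustration based on the rest of the module.

// Hypothetical driver; identifiers other than decodeBlockRingBuffer, RingBuffer,
// and DecodeState are assumptions for illustration.
fn decodeFrameBlocks(
    window: *RingBuffer,
    src: []const u8,
    decode_state: *DecodeState,
    block_size_maximum: usize,
) !usize {
    var consumed: usize = 0;
    var total_written: usize = 0;
    while (true) {
        // Assumed helper parsing the 3-byte block header.
        const header = decodeBlockHeader(src[consumed..][0..3]);
        consumed += 3;
        total_written += try decodeBlockRingBuffer(
            window,
            src[consumed..],
            header,
            decode_state,
            &consumed,
            block_size_maximum,
        );
        // Decoded bytes can be drained from `window` here (e.g. via window.read())
        // before the next block overwrites old history.
        if (header.last_block) break;
    }
    return total_written;
}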