zig/lib/std/compress/zstandard/decompress.zig
2023-02-20 09:09:06 +11:00

1763 lines
70 KiB
Zig

const std = @import("std");
const assert = std.debug.assert;
const types = @import("types.zig");
const frame = types.frame;
const LiteralsSection = types.compressed_block.LiteralsSection;
const SequencesSection = types.compressed_block.SequencesSection;
const Table = types.compressed_block.Table;
pub const RingBuffer = @import("RingBuffer.zig");
const readInt = std.mem.readIntLittle;
const readIntSlice = std.mem.readIntSliceLittle;
/// Read a little-endian integer of type `T` from `bytes`; `bytes` may be
/// shorter than `@sizeOf(T)`.
fn readVarInt(comptime T: type, bytes: []const u8) T {
    const endianness = std.builtin.Endian.Little;
    return std.mem.readVarInt(T, bytes, endianness);
}
/// Returns true if `magic` falls within the magic-number range reserved for
/// skippable frames.
pub fn isSkippableMagic(magic: u32) bool {
    const min = frame.Skippable.magic_number_min;
    const max = frame.Skippable.magic_number_max;
    return min <= magic and magic <= max;
}
/// Returns the kind of frame at the beginning of `src`.
///
/// Errors:
///   - returns `error.BadMagic` if `source` begins with bytes not equal to the
///     Zstandard frame magic number, or outside the range of magic numbers for
///     skippable frames.
pub fn decodeFrameType(source: anytype) !frame.Kind {
    const magic = try source.readIntLittle(u32);
    if (magic == frame.ZStandard.magic_number) return .zstandard;
    if (isSkippableMagic(magic)) return .skippable;
    return error.BadMagic;
}
/// Byte counts reported by a decode operation.
const ReadWriteCount = struct {
    /// Number of bytes consumed from the source buffer.
    read_count: usize,
    /// Number of bytes produced into the destination buffer.
    write_count: usize,
};
/// Decodes the frame at the start of `src` into `dest`. Returns the number of
/// bytes read from `src` and written to `dest`.
///
/// Errors:
///   - returns `error.UnknownContentSizeUnsupported`
///   - returns `error.ContentTooLarge`
///   - returns `error.BadMagic`
pub fn decodeFrame(
    dest: []u8,
    src: []const u8,
    verify_checksum: bool,
) !ReadWriteCount {
    var fbs = std.io.fixedBufferStream(src);
    const reader = fbs.reader();
    switch (try decodeFrameType(reader)) {
        .zstandard => return decodeZStandardFrame(dest, src, verify_checksum),
        .skippable => {
            // A skippable frame produces no output; the whole frame is
            // magic (4 bytes) + frame-size field (4 bytes) + payload.
            const frame_size = try reader.readIntLittle(u32);
            return ReadWriteCount{
                .read_count = frame_size + 8,
                .write_count = 0,
            };
        },
    }
}
/// Decompression state carried across the blocks of a single Zstandard frame:
/// the three most-recent offsets, the FSE state/table for each of the offset,
/// match-length and literal-length codes, and the literals stream currently
/// being read (with its Huffman tree when literals are Huffman-compressed).
pub const DecodeState = struct {
    repeat_offsets: [3]u32,

    offset: StateData(8),
    match: StateData(9),
    literal: StateData(9),

    offset_fse_buffer: []Table.Fse,
    match_fse_buffer: []Table.Fse,
    literal_fse_buffer: []Table.Fse,

    // True until the first block defines FSE tables; `repeat` table mode is
    // an error while this is set.
    fse_tables_undefined: bool,

    literal_stream_reader: ReverseBitReader,
    literal_stream_index: usize,
    literal_streams: LiteralsSection.Streams,
    literal_header: LiteralsSection.Header,
    huffman_tree: ?LiteralsSection.HuffmanTree,

    // Number of literal bytes decoded so far from the current block's
    // literals section.
    literal_written_count: usize,

    /// FSE decoding state for one code kind: current state value, decoding
    /// table, and the table's accuracy log. `max_accuracy_log` bounds the
    /// width of the state integer.
    fn StateData(comptime max_accuracy_log: comptime_int) type {
        return struct {
            state: State,
            table: Table,
            accuracy_log: u8,

            const State = std.meta.Int(.unsigned, max_accuracy_log);
        };
    }

    /// Create a `DecodeState` whose FSE tables will be stored in the given
    /// buffers. Most fields remain undefined until `prepare()` is called for
    /// the first compressed block.
    pub fn init(
        literal_fse_buffer: []Table.Fse,
        match_fse_buffer: []Table.Fse,
        offset_fse_buffer: []Table.Fse,
    ) DecodeState {
        return DecodeState{
            .repeat_offsets = .{
                types.compressed_block.start_repeated_offset_1,
                types.compressed_block.start_repeated_offset_2,
                types.compressed_block.start_repeated_offset_3,
            },
            .offset = undefined,
            .match = undefined,
            .literal = undefined,
            .literal_fse_buffer = literal_fse_buffer,
            .match_fse_buffer = match_fse_buffer,
            .offset_fse_buffer = offset_fse_buffer,
            .fse_tables_undefined = true,
            .literal_written_count = 0,
            .literal_header = undefined,
            .literal_streams = undefined,
            .literal_stream_reader = undefined,
            .literal_stream_index = undefined,
            .huffman_tree = null,
        };
    }

    /// Prepare the decoder to decode a compressed block. Loads the literals
    /// stream and Huffman tree from `literals` and reads the FSE tables from
    /// `source`.
    ///
    /// Errors:
    ///   - returns `error.BitStreamHasNoStartBit` if the (reversed) literal
    ///     bitstream's first byte does not have any bits set.
    ///   - returns `error.TreelessLiteralsFirst` if `literals` is a treeless
    ///     literals section and the decode state does not have a Huffman tree
    ///     from a previous block.
    pub fn prepare(
        self: *DecodeState,
        source: anytype,
        literals: LiteralsSection,
        sequences_header: SequencesSection.Header,
    ) !void {
        self.literal_written_count = 0;
        self.literal_header = literals.header;
        self.literal_streams = literals.streams;

        // A treeless section reuses the previous block's tree; otherwise the
        // section must carry its own.
        if (literals.huffman_tree) |tree| {
            self.huffman_tree = tree;
        } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
            return error.TreelessLiteralsFirst;
        }

        switch (literals.header.block_type) {
            .raw, .rle => {},
            .compressed, .treeless => {
                self.literal_stream_index = 0;
                switch (literals.streams) {
                    .one => |slice| try self.initLiteralStream(slice),
                    .four => |streams| try self.initLiteralStream(streams[0]),
                }
            },
        }

        if (sequences_header.sequence_count > 0) {
            // Table descriptions appear in this order in the stream:
            // literals lengths, offsets, match lengths.
            try self.updateFseTable(source, .literal, sequences_header.literal_lengths);
            try self.updateFseTable(source, .offset, sequences_header.offsets);
            try self.updateFseTable(source, .match, sequences_header.match_lengths);
            self.fse_tables_undefined = false;
        }
    }

    /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream`
    /// if `bit_reader` does not contain enough bits.
    pub fn readInitialFseState(self: *DecodeState, bit_reader: *ReverseBitReader) error{EndOfStream}!void {
        // Initial states are stored in this order: literals lengths,
        // offsets, match lengths.
        self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
        self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
        self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
    }

    /// Push `offset` to the front of the repeat-offset history, shifting the
    /// previous entries back.
    fn updateRepeatOffset(self: *DecodeState, offset: u32) void {
        std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]);
        std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
        self.repeat_offsets[0] = offset;
    }

    /// Use the repeat offset at `index`, moving it to the front of the
    /// history, and return it.
    fn useRepeatOffset(self: *DecodeState, index: usize) u32 {
        if (index == 1)
            std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1])
        else if (index == 2) {
            std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
            std.mem.swap(u32, &self.repeat_offsets[1], &self.repeat_offsets[2]);
        }
        return self.repeat_offsets[0];
    }

    const DataType = enum { offset, match, literal };

    /// Advance the FSE state for `choice` by reading the entry's extra bits
    /// and jumping to `baseline + bits`; RLE tables carry no state to update.
    fn updateState(
        self: *DecodeState,
        comptime choice: DataType,
        bit_reader: *ReverseBitReader,
    ) error{ MalformedFseBits, EndOfStream }!void {
        switch (@field(self, @tagName(choice)).table) {
            .rle => {},
            .fse => |table| {
                const data = table[@field(self, @tagName(choice)).state];
                const T = @TypeOf(@field(self, @tagName(choice))).State;
                const bits_summand = try bit_reader.readBitsNoEof(T, data.bits);
                const next_state = std.math.cast(
                    @TypeOf(@field(self, @tagName(choice))).State,
                    data.baseline + bits_summand,
                ) orelse return error.MalformedFseBits;
                @field(self, @tagName(choice)).state = next_state;
            },
        }
    }

    const FseTableError = error{
        MalformedFseTable,
        MalformedAccuracyLog,
        RepeatModeFirst,
        EndOfStream,
    };

    /// Set the FSE table for `choice` according to `mode`: a predefined
    /// table, a single-symbol RLE table, a table decoded from `source`, or
    /// reuse of the previous block's table (`repeat`).
    fn updateFseTable(
        self: *DecodeState,
        source: anytype,
        comptime choice: DataType,
        mode: SequencesSection.Header.Mode,
    ) !void {
        const field_name = @tagName(choice);
        switch (mode) {
            .predefined => {
                @field(self, field_name).accuracy_log =
                    @field(types.compressed_block.default_accuracy_log, field_name);

                @field(self, field_name).table =
                    @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
            },
            .rle => {
                @field(self, field_name).accuracy_log = 0;
                @field(self, field_name).table = .{ .rle = try source.readByte() };
            },
            .fse => {
                var bit_reader = bitReader(source);

                const table_size = try decodeFseTable(
                    &bit_reader,
                    @field(types.compressed_block.table_symbol_count_max, field_name),
                    @field(types.compressed_block.table_accuracy_log_max, field_name),
                    @field(self, field_name ++ "_fse_buffer"),
                );
                @field(self, field_name).table = .{
                    .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
                };
                @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
            },
            .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst,
        }
    }

    const Sequence = struct {
        literal_length: u32,
        match_length: u32,
        offset: u32,
    };

    /// Decode the next sequence's (literal_length, match_length, offset)
    /// triple from `bit_reader` using the current FSE states, updating the
    /// repeat-offset history as required.
    fn nextSequence(
        self: *DecodeState,
        bit_reader: *ReverseBitReader,
    ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence {
        const raw_code = self.getCode(.offset);
        const offset_code = std.math.cast(u5, raw_code) orelse {
            return error.OffsetCodeTooLarge;
        };
        const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);

        const match_code = self.getCode(.match);
        const match = types.compressed_block.match_length_code_table[match_code];
        const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);

        const literal_code = self.getCode(.literal);
        const literal = types.compressed_block.literals_length_code_table[literal_code];
        const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);

        // offset_value 1-3 select repeat offsets (shifted by one when
        // literal_length == 0); larger values encode offset_value - 3.
        const offset = if (offset_value > 3) offset: {
            const offset = offset_value - 3;
            self.updateRepeatOffset(offset);
            break :offset offset;
        } else offset: {
            if (literal_length == 0) {
                if (offset_value == 3) {
                    const offset = self.repeat_offsets[0] - 1;
                    self.updateRepeatOffset(offset);
                    break :offset offset;
                }
                break :offset self.useRepeatOffset(offset_value);
            }
            break :offset self.useRepeatOffset(offset_value - 1);
        };

        return .{
            .literal_length = literal_length,
            .match_length = match_length,
            .offset = offset,
        };
    }

    /// Decode `sequence`'s literals into `dest` at `write_pos`, then copy
    /// the match of `sequence.match_length` bytes from `sequence.offset`
    /// bytes back in `dest`.
    fn executeSequenceSlice(
        self: *DecodeState,
        dest: []u8,
        write_pos: usize,
        sequence: Sequence,
    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
        if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;

        try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length);
        const copy_start = write_pos + sequence.literal_length - sequence.offset;
        const copy_end = copy_start + sequence.match_length;
        // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
        // to allow repeats
        std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
    }

    /// As `executeSequenceSlice`, but writing into a ring buffer; the match
    /// source may wrap around the buffer.
    fn executeSequenceRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        sequence: Sequence,
    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
        if (sequence.offset > dest.data.len) return error.MalformedSequence;

        try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
        const copy_start = dest.write_index + dest.data.len - sequence.offset;
        const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
        // TODO: would std.mem.copy and figuring out dest slice be better/faster?
        for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
        for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
    }

    const DecodeSequenceError = error{
        OffsetCodeTooLarge,
        EndOfStream,
        MalformedSequence,
        MalformedFseBits,
    } || DecodeLiteralsError;

    /// Decode one sequence from `bit_reader` into `dest`, written starting at
    /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
    /// `error.MalformedSequence` error if the decompressed sequence would be longer
    /// than `sequence_size_limit` or the sequence's offset is too large; returns
    /// `error.EndOfStream` if `bit_reader` does not contain enough bits; returns
    /// `error.UnexpectedEndOfLiteralStream` if the decoder state's literal streams
    /// do not contain enough literals for the sequence (this may mean the literal
    /// stream or the sequence is malformed).
    pub fn decodeSequenceSlice(
        self: *DecodeState,
        dest: []u8,
        write_pos: usize,
        bit_reader: *ReverseBitReader,
        sequence_size_limit: usize,
        last_sequence: bool,
    ) DecodeSequenceError!usize {
        const sequence = try self.nextSequence(bit_reader);
        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
        if (sequence_length > sequence_size_limit) return error.MalformedSequence;

        try self.executeSequenceSlice(dest, write_pos, sequence);
        // FSE states are only advanced between sequences; the final sequence
        // has no trailing state bits.
        if (!last_sequence) {
            try self.updateState(.literal, bit_reader);
            try self.updateState(.match, bit_reader);
            try self.updateState(.offset, bit_reader);
        }
        return sequence_length;
    }

    /// Decode one sequence from `bit_reader` into `dest`; see `decodeSequenceSlice`.
    pub fn decodeSequenceRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        bit_reader: anytype,
        sequence_size_limit: usize,
        last_sequence: bool,
    ) DecodeSequenceError!usize {
        const sequence = try self.nextSequence(bit_reader);
        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
        if (sequence_length > sequence_size_limit) return error.MalformedSequence;

        try self.executeSequenceRingBuffer(dest, sequence);
        // FSE states are only advanced between sequences; the final sequence
        // has no trailing state bits.
        if (!last_sequence) {
            try self.updateState(.literal, bit_reader);
            try self.updateState(.match, bit_reader);
            try self.updateState(.offset, bit_reader);
        }
        return sequence_length;
    }

    /// Switch the literal stream reader to the next of the four literal
    /// streams.
    fn nextLiteralMultiStream(
        self: *DecodeState,
    ) error{BitStreamHasNoStartBit}!void {
        self.literal_stream_index += 1;
        try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
    }

    /// Point the (reversed) literal stream reader at `bytes`.
    pub fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
        try self.literal_stream_reader.init(bytes);
    }

    const LiteralBitsError = error{
        BitStreamHasNoStartBit,
        UnexpectedEndOfLiteralStream,
    };

    /// Read `bit_count_to_read` bits of Huffman-coded literals, advancing to
    /// the next literal stream if the current one runs out (four-stream case).
    /// Reads as type `T` (previously hard-coded to `u16`, ignoring the
    /// comptime parameter; all existing callers pass `u16`).
    fn readLiteralsBits(
        self: *DecodeState,
        comptime T: type,
        bit_count_to_read: usize,
    ) LiteralBitsError!T {
        return self.literal_stream_reader.readBitsNoEof(T, bit_count_to_read) catch bits: {
            if (self.literal_streams == .four and self.literal_stream_index < 3) {
                try self.nextLiteralMultiStream();
                break :bits self.literal_stream_reader.readBitsNoEof(T, bit_count_to_read) catch
                    return error.UnexpectedEndOfLiteralStream;
            } else {
                return error.UnexpectedEndOfLiteralStream;
            }
        };
    }

    const DecodeLiteralsError = error{
        MalformedLiteralsLength,
        PrefixNotFound,
    } || LiteralBitsError;

    /// Decode `len` bytes of literals into `dest`. `literals` should be the
    /// `LiteralsSection` that was passed to `prepare()`. Returns
    /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by
    /// `self` plus `len` is greater than the regenerated size of `literals`.
    /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if
    /// there are problems decoding Huffman compressed literals.
    pub fn decodeLiteralsSlice(
        self: *DecodeState,
        dest: []u8,
        len: usize,
    ) DecodeLiteralsError!void {
        if (self.literal_written_count + len > self.literal_header.regenerated_size)
            return error.MalformedLiteralsLength;

        switch (self.literal_header.block_type) {
            .raw => {
                // Raw literals: copy straight from the stored stream.
                const literals_end = self.literal_written_count + len;
                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                std.mem.copy(u8, dest, literal_data);
                self.literal_written_count += len;
            },
            .rle => {
                // RLE literals: the single stored byte repeated.
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    dest[i] = self.literal_streams.one[0];
                }
                self.literal_written_count += len;
            },
            .compressed, .treeless => {
                const huffman_tree = self.huffman_tree orelse unreachable;
                const max_bit_count = huffman_tree.max_bit_count;
                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                    max_bit_count,
                );
                var bits_read: u4 = 0;
                var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
                var bit_count_to_read: u4 = starting_bit_count;

                // Walk the Huffman tree prefix-by-prefix for each output byte.
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    var prefix: u16 = 0;
                    while (true) {
                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                        prefix <<= bit_count_to_read;
                        prefix |= new_bits;
                        bits_read += bit_count_to_read;
                        const result = try huffman_tree.query(huffman_tree_index, prefix);

                        switch (result) {
                            .symbol => |sym| {
                                // Complete codeword: emit and restart at the root.
                                dest[i] = sym;
                                bit_count_to_read = starting_bit_count;
                                bits_read = 0;
                                huffman_tree_index = huffman_tree.symbol_count_minus_one;
                                break;
                            },
                            .index => |index| {
                                // Partial codeword: descend and read the
                                // remaining bits for this depth.
                                huffman_tree_index = index;
                                const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                                    huffman_tree.nodes[index].weight,
                                    max_bit_count,
                                );
                                bit_count_to_read = bit_count - bits_read;
                            },
                        }
                    }
                }
                self.literal_written_count += len;
            },
        }
    }

    /// Decode literals into `dest`; see `decodeLiteralsSlice()`.
    pub fn decodeLiteralsRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        len: usize,
    ) DecodeLiteralsError!void {
        if (self.literal_written_count + len > self.literal_header.regenerated_size)
            return error.MalformedLiteralsLength;

        switch (self.literal_header.block_type) {
            .raw => {
                // Raw literals: copy straight from the stored stream.
                const literals_end = self.literal_written_count + len;
                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                dest.writeSliceAssumeCapacity(literal_data);
                self.literal_written_count += len;
            },
            .rle => {
                // RLE literals: the single stored byte repeated.
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    dest.writeAssumeCapacity(self.literal_streams.one[0]);
                }
                self.literal_written_count += len;
            },
            .compressed, .treeless => {
                const huffman_tree = self.huffman_tree orelse unreachable;
                const max_bit_count = huffman_tree.max_bit_count;
                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                    max_bit_count,
                );
                var bits_read: u4 = 0;
                var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
                var bit_count_to_read: u4 = starting_bit_count;

                // Walk the Huffman tree prefix-by-prefix for each output byte.
                var i: usize = 0;
                while (i < len) : (i += 1) {
                    var prefix: u16 = 0;
                    while (true) {
                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                        prefix <<= bit_count_to_read;
                        prefix |= new_bits;
                        bits_read += bit_count_to_read;
                        const result = try huffman_tree.query(huffman_tree_index, prefix);

                        switch (result) {
                            .symbol => |sym| {
                                // Complete codeword: emit and restart at the root.
                                dest.writeAssumeCapacity(sym);
                                bit_count_to_read = starting_bit_count;
                                bits_read = 0;
                                huffman_tree_index = huffman_tree.symbol_count_minus_one;
                                break;
                            },
                            .index => |index| {
                                // Partial codeword: descend and read the
                                // remaining bits for this depth.
                                huffman_tree_index = index;
                                const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                                    huffman_tree.nodes[index].weight,
                                    max_bit_count,
                                );
                                bit_count_to_read = bit_count - bits_read;
                            },
                        }
                    }
                }
                self.literal_written_count += len;
            },
        }
    }

    /// Return the current code for `choice`: the stored byte for an RLE
    /// table, or the symbol at the current FSE state otherwise.
    fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
        return switch (@field(self, @tagName(choice)).table) {
            .rle => |value| value,
            .fse => |table| table[@field(self, @tagName(choice)).state].symbol,
        };
    }
};
/// Finalize `hasher` and return the low 32 bits of the 64-bit digest, as
/// used for frame content checksums.
pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
    const digest = hasher.final();
    return @truncate(u32, digest);
}
/// Errors that can occur while decoding a whole Zstandard frame.
const FrameError = error{
    DictionaryIdFlagUnsupported,
    ChecksumFailure,
} || InvalidBit || DecodeBlockError;
/// Decode a Zstandard frame from `src` into `dest`, returning the number of
/// bytes read from `src` and written to `dest`; if the frame does not declare
/// its decompressed content size `error.UnknownContentSizeUnsupported` is
/// returned. Returns `error.DictionaryIdFlagUnsupported` if the frame uses a
/// dictionary, and `error.ChecksumFailure` if `verify_checksum` is `true` and
/// the frame contains a checksum that does not match the checksum computed from
/// the decompressed frame.
pub fn decodeZStandardFrame(
    dest: []u8,
    src: []const u8,
    verify_checksum: bool,
) (error{ UnknownContentSizeUnsupported, ContentTooLarge, EndOfStream } || FrameError)!ReadWriteCount {
    // The caller must already have checked the magic number (e.g. via
    // decodeFrameType()).
    assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
    var consumed_count: usize = 4;

    // Parse the frame header from the bytes following the magic number.
    var fbs = std.io.fixedBufferStream(src[consumed_count..]);
    var source = fbs.reader();
    const frame_header = try decodeZStandardHeader(source);
    consumed_count += fbs.pos;

    if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;

    // This entry point requires a declared content size so `dest` can be
    // validated up front.
    const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported;
    if (dest.len < content_size) return error.ContentTooLarge;

    // Only hash the output when a checksum is both present and requested.
    const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;

    const written_count = try decodeFrameBlocks(
        dest,
        src[consumed_count..],
        &consumed_count,
        if (hasher_opt) |*hasher| hasher else null,
    );

    if (frame_header.descriptor.content_checksum_flag) {
        // The 4-byte checksum trails the last block; it is consumed even
        // when verification is disabled.
        const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
        consumed_count += 4;
        if (hasher_opt) |*hasher| {
            if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
        }
    }
    return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count };
}
/// Per-frame decoding context derived from a validated frame header.
pub const FrameContext = struct {
    hasher_opt: ?std.hash.XxHash64,
    window_size: usize,
    has_checksum: bool,
    block_size_max: usize,

    /// Validate `frame_header` and build a decoding context. Fails with
    /// `error.DictionaryIdFlagUnsupported`, `error.WindowSizeUnknown`, or
    /// `error.WindowTooLarge`.
    pub fn init(frame_header: frame.ZStandard.Header, window_size_max: usize, verify_checksum: bool) !FrameContext {
        if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;

        const raw_window_size = frameWindowSize(frame_header) orelse return error.WindowSizeUnknown;
        if (raw_window_size > window_size_max) return error.WindowTooLarge;
        const window_size = @intCast(usize, raw_window_size);

        const checksum_present = frame_header.descriptor.content_checksum_flag;
        return FrameContext{
            // Only hash when a checksum is both present and requested.
            .hasher_opt = if (checksum_present and verify_checksum) std.hash.XxHash64.init(0) else null,
            .window_size = window_size,
            .has_checksum = checksum_present,
            .block_size_max = @min(1 << 17, window_size),
        };
    }
};
/// Decode a Zstandard frame from `src` and return the decompressed bytes,
/// allocated with `allocator` and owned by the caller; see
/// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame
/// does not declare its content size or a window descriptor (this indicates a
/// malformed frame).
///
/// Errors:
///   - returns `error.WindowTooLarge`
///   - returns `error.WindowSizeUnknown`
pub fn decodeZStandardFrameAlloc(
    allocator: std.mem.Allocator,
    src: []const u8,
    verify_checksum: bool,
    window_size_max: usize,
) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory, EndOfStream } || FrameError)![]u8 {
    var result = std.ArrayList(u8).init(allocator);
    // Free the partially-built output on any error path below; on success,
    // toOwnedSlice() transfers ownership to the caller before we return.
    errdefer result.deinit();

    // The caller must already have checked the magic number.
    assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
    var consumed_count: usize = 4;

    var frame_context = context: {
        var fbs = std.io.fixedBufferStream(src[consumed_count..]);
        var source = fbs.reader();
        const frame_header = try decodeZStandardHeader(source);
        consumed_count += fbs.pos;
        break :context try FrameContext.init(frame_header, window_size_max, verify_checksum);
    };

    // The ring buffer serves as the match window, sized from the frame header.
    var ring_buffer = try RingBuffer.init(allocator, frame_context.window_size);
    defer ring_buffer.deinit(allocator);

    // These tables take 7680 bytes
    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;

    var block_header = try decodeBlockHeaderSlice(src[consumed_count..]);
    consumed_count += 3; // block headers are always 3 bytes
    var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data);
    while (true) : ({
        block_header = try decodeBlockHeaderSlice(src[consumed_count..]);
        consumed_count += 3;
    }) {
        if (block_header.block_size > frame_context.block_size_max) return error.BlockSizeOverMaximum;
        const written_size = try decodeBlockRingBuffer(
            &ring_buffer,
            src[consumed_count..],
            block_header,
            &decode_state,
            &consumed_count,
            frame_context.block_size_max,
        );
        // Copy this block's decompressed bytes (possibly wrapped, hence two
        // sub-slices) out of the ring buffer into the result.
        const written_slice = ring_buffer.sliceLast(written_size);
        try result.appendSlice(written_slice.first);
        try result.appendSlice(written_slice.second);
        if (frame_context.hasher_opt) |*hasher| {
            hasher.update(written_slice.first);
            hasher.update(written_slice.second);
        }
        if (block_header.last_block) break;
    }
    if (frame_context.has_checksum) {
        // The 4-byte checksum trails the last block; consumed even when
        // verification is disabled.
        const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
        consumed_count += 4;
        if (frame_context.hasher_opt) |*hasher| {
            if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
        }
    }
    return result.toOwnedSlice();
}
/// Errors that can occur while decoding a single block.
const DecodeBlockError = error{
    BlockSizeOverMaximum,
    MalformedBlockSize,
    ReservedBlock,
    MalformedRleBlock,
    MalformedCompressedBlock,
    EndOfStream,
};
/// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`.
/// `src` should begin at the first block header. Returns the total number of
/// decompressed bytes written to `dest` and adds the number of bytes consumed
/// from `src` to `consumed_count` (also on error, via the `defer` below). If
/// `hash` is non-null it is updated with every decompressed byte.
pub fn decodeFrameBlocks(
    dest: []u8,
    src: []const u8,
    consumed_count: *usize,
    hash: ?*std.hash.XxHash64,
) DecodeBlockError!usize {
    // These tables take 7680 bytes
    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;

    var block_header = try decodeBlockHeaderSlice(src);
    var bytes_read: usize = 3; // block headers are always 3 bytes
    // Report consumption to the caller even when a block fails to decode.
    defer consumed_count.* += bytes_read;
    var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data);
    var written_count: usize = 0;
    while (true) : ({
        // Read the next block's header before each subsequent iteration.
        block_header = try decodeBlockHeaderSlice(src[bytes_read..]);
        bytes_read += 3;
    }) {
        const written_size = try decodeBlock(
            dest,
            src[bytes_read..],
            block_header,
            &decode_state,
            &bytes_read,
            written_count,
        );
        // Feed this block's freshly decompressed bytes to the optional hasher.
        if (hash) |hash_state| hash_state.update(dest[written_count .. written_count + written_size]);
        written_count += written_size;
        if (block_header.last_block) break;
    }
    return written_count;
}
/// Decode a single block from `src` into `dest`. The beginning of `src` should
/// be the start of the block content (i.e. directly after the block header).
/// Increments `consumed_count` by the number of bytes read from `src` to decode
/// the block and returns the decompressed size of the block.
pub fn decodeBlock(
    dest: []u8,
    src: []const u8,
    block_header: frame.ZStandard.Block.Header,
    decode_state: *DecodeState,
    consumed_count: *usize,
    written_count: usize,
) DecodeBlockError!usize {
    const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
    const block_size = block_header.block_size;
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            // Raw block: content is stored uncompressed.
            if (src.len < block_size) return error.MalformedBlockSize;
            const data = src[0..block_size];
            std.mem.copy(u8, dest[written_count..], data);
            consumed_count.* += block_size;
            return block_size;
        },
        .rle => {
            // RLE block: a single source byte repeated block_size times.
            if (src.len < 1) return error.MalformedRleBlock;
            var write_pos: usize = written_count;
            while (write_pos < block_size + written_count) : (write_pos += 1) {
                dest[write_pos] = src[0];
            }
            consumed_count.* += 1;
            return block_size;
        },
        .compressed => {
            if (src.len < block_size) return error.MalformedBlockSize;
            var bytes_read: usize = 0;
            // Literals section first, then the sequences section header.
            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
                return error.MalformedCompressedBlock;
            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
            const fbs_reader = fbs.reader();
            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                return error.MalformedCompressedBlock;

            // Load the literal stream / Huffman tree and the FSE tables.
            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                return error.MalformedCompressedBlock;

            bytes_read += fbs.pos;
            var bytes_written: usize = 0;
            if (sequences_header.sequence_count > 0) {
                // The sequence bitstream fills the rest of the block and is
                // read in reverse.
                const bit_stream_bytes = src[bytes_read..block_size];
                var bit_stream: ReverseBitReader = undefined;
                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;

                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;

                // Each sequence may write at most this many remaining bytes.
                var sequence_size_limit = block_size_max;
                var i: usize = 0;
                while (i < sequences_header.sequence_count) : (i += 1) {
                    const write_pos = written_count + bytes_written;
                    const decompressed_size = decode_state.decodeSequenceSlice(
                        dest,
                        write_pos,
                        &bit_stream,
                        sequence_size_limit,
                        i == sequences_header.sequence_count - 1,
                    ) catch return error.MalformedCompressedBlock;
                    bytes_written += decompressed_size;
                    sequence_size_limit -= decompressed_size;
                }

                bytes_read += bit_stream_bytes.len;
            }
            // The sections must account for exactly the declared block size.
            if (bytes_read != block_size) return error.MalformedCompressedBlock;

            // Flush any literals not consumed by the sequences.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
                    return error.MalformedCompressedBlock;
                bytes_written += len;
            }
            consumed_count.* += bytes_read;
            return bytes_written;
        },
        .reserved => return error.ReservedBlock,
    }
}
/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns
/// the size of the decompressed block, which can be used with `dest.sliceLast()`
/// to get the decompressed bytes.
pub fn decodeBlockRingBuffer(
    dest: *RingBuffer,
    src: []const u8,
    block_header: frame.ZStandard.Block.Header,
    decode_state: *DecodeState,
    consumed_count: *usize,
    block_size_max: usize,
) DecodeBlockError!usize {
    const block_size = block_header.block_size;
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            // Raw block: content is stored uncompressed.
            if (src.len < block_size) return error.MalformedBlockSize;
            const data = src[0..block_size];
            dest.writeSliceAssumeCapacity(data);
            consumed_count.* += block_size;
            return block_size;
        },
        .rle => {
            // RLE block: a single source byte repeated block_size times.
            if (src.len < 1) return error.MalformedRleBlock;
            var write_pos: usize = 0;
            while (write_pos < block_size) : (write_pos += 1) {
                dest.writeAssumeCapacity(src[0]);
            }
            consumed_count.* += 1;
            return block_size;
        },
        .compressed => {
            if (src.len < block_size) return error.MalformedBlockSize;
            var bytes_read: usize = 0;
            // Literals section first, then the sequences section header.
            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
                return error.MalformedCompressedBlock;
            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
            const fbs_reader = fbs.reader();
            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                return error.MalformedCompressedBlock;

            // Load the literal stream / Huffman tree and the FSE tables.
            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                return error.MalformedCompressedBlock;

            bytes_read += fbs.pos;
            var bytes_written: usize = 0;
            if (sequences_header.sequence_count > 0) {
                // The sequence bitstream fills the rest of the block and is
                // read in reverse.
                const bit_stream_bytes = src[bytes_read..block_size];
                var bit_stream: ReverseBitReader = undefined;
                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;

                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;

                // Each sequence may write at most this many remaining bytes.
                var sequence_size_limit = block_size_max;
                var i: usize = 0;
                while (i < sequences_header.sequence_count) : (i += 1) {
                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
                        dest,
                        &bit_stream,
                        sequence_size_limit,
                        i == sequences_header.sequence_count - 1,
                    ) catch return error.MalformedCompressedBlock;
                    bytes_written += decompressed_size;
                    sequence_size_limit -= decompressed_size;
                }

                bytes_read += bit_stream_bytes.len;
            }
            // The sections must account for exactly the declared block size.
            if (bytes_read != block_size) return error.MalformedCompressedBlock;

            // Flush any literals not consumed by the sequences.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                    return error.MalformedCompressedBlock;
                bytes_written += len;
            }
            consumed_count.* += bytes_read;
            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
            return bytes_written;
        },
        .reserved => return error.ReservedBlock,
    }
}
/// Decode a single block from `source` into `dest`. Literal and sequence data
/// from the block is copied into `literals_buffer` and `sequence_buffer`, which
/// must be large enough or `error.LiteralsBufferTooSmall` and
/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an
/// upper bound for the size of both buffers). See `decodeBlock`
/// and `decodeBlockRingBuffer` for function that can decode a block without
/// these extra copies.
pub fn decodeBlockReader(
    dest: *RingBuffer,
    source: anytype,
    block_header: frame.ZStandard.Block.Header,
    decode_state: *DecodeState,
    block_size_max: usize,
    literals_buffer: []u8,
    sequence_buffer: []u8,
) !void {
    const block_size = block_header.block_size;
    // The limited reader guards compressed-section parsing against reading
    // past the end of this block. Raw and RLE blocks read from `source`
    // directly since they consume a known number of bytes.
    var block_reader_limited = std.io.limitedReader(source, block_size);
    const block_reader = block_reader_limited.reader();
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            // Copy the uncompressed bytes straight into the ring buffer; the
            // destination may wrap, hence the two sub-slices.
            const slice = dest.sliceAt(dest.write_index, block_size);
            try source.readNoEof(slice.first);
            try source.readNoEof(slice.second);
            dest.write_index = dest.mask2(dest.write_index + block_size);
        },
        .rle => {
            // RLE block: one stored byte repeated block_size times.
            const byte = try source.readByte();
            var i: usize = 0;
            while (i < block_size) : (i += 1) {
                dest.writeAssumeCapacity(byte);
            }
        },
        .compressed => {
            const literals = try decodeLiteralsSection(block_reader, literals_buffer);
            const sequences_header = try decodeSequencesHeader(block_reader);

            try decode_state.prepare(block_reader, literals, sequences_header);

            if (sequences_header.sequence_count > 0) {
                // Buffer the remainder of the block: the sequence bitstream
                // is read backwards, so it cannot be streamed.
                if (sequence_buffer.len < block_reader_limited.bytes_left)
                    return error.SequenceBufferTooSmall;
                const size = try block_reader.readAll(sequence_buffer);
                var bit_stream: ReverseBitReader = undefined;
                try bit_stream.init(sequence_buffer[0..size]);

                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
                var sequence_size_limit = block_size_max;
                var i: usize = 0;
                while (i < sequences_header.sequence_count) : (i += 1) {
                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
                        dest,
                        &bit_stream,
                        sequence_size_limit,
                        i == sequences_header.sequence_count - 1,
                    ) catch return error.MalformedCompressedBlock;
                    sequence_size_limit -= decompressed_size;
                }
            }
            // Flush any literals not consumed by the sequences.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                    return error.MalformedCompressedBlock;
            }
            decode_state.literal_written_count = 0;
            // The whole block should now have been consumed.
            assert(block_reader.readByte() == error.EndOfStream);
        },
        .reserved => return error.ReservedBlock,
    }
}
/// Decode the header of a skippable frame. Asserts that `src` begins with a
/// valid skippable-frame magic number.
pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
    const magic_number = readInt(u32, src[0..4]);
    assert(isSkippableMagic(magic_number));
    return frame.Skippable.Header{
        .magic_number = magic_number,
        .frame_size = readInt(u32, src[4..8]),
    };
}
/// Returns the window size required to decompress a frame, or `null` if it
/// cannot be determined, which indicates a malformed frame header.
pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
    // Without a window descriptor, the content size bounds the window.
    const descriptor = header.window_descriptor orelse return header.content_size;
    const exponent = (descriptor & 0b11111000) >> 3;
    const mantissa = descriptor & 0b00000111;
    // window = 2^(10 + exponent), plus `mantissa` eighths of that base.
    const window_base = @as(u64, 1) << @intCast(u6, 10 + exponent);
    return window_base + (window_base / 8) * mantissa;
}
/// Returned when a frame-header descriptor has its unused or reserved bit set.
const InvalidBit = error{ UnusedBitSet, ReservedBitSet };
/// Decode the header of a Zstandard frame.
///
/// Errors:
///   - returns `error.UnusedBitSet` if the unused bit of the header is set
///   - returns `error.ReservedBitSet` if the reserved bit of the header is set
///   - returns `error.EndOfStream` if `source` ends before the header does
pub fn decodeZStandardHeader(source: anytype) (error{EndOfStream} || InvalidBit)!frame.ZStandard.Header {
    const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, try source.readByte());

    if (descriptor.unused) return error.UnusedBitSet;
    if (descriptor.reserved) return error.ReservedBitSet;

    // The window descriptor is omitted for single-segment frames.
    const window_descriptor: ?u8 = if (descriptor.single_segment_flag)
        null
    else
        try source.readByte();

    const dictionary_id: ?u32 = if (descriptor.dictionary_id_flag > 0) id: {
        // if flag is 3 then field_size = 4, else field_size = flag
        const field_size = (@as(u4, 1) << descriptor.dictionary_id_flag) >> 1;
        break :id try source.readVarInt(u32, .Little, field_size);
    } else null;

    // The content size field is 1, 2, 4, or 8 bytes wide; single-segment
    // frames always carry it. The 2-byte form stores the value minus 256.
    const content_size: ?u64 = if (descriptor.single_segment_flag or descriptor.content_size_flag > 0) size: {
        const field_size = @as(u4, 1) << descriptor.content_size_flag;
        const value = try source.readVarInt(u64, .Little, field_size);
        break :size if (field_size == 2) value + 256 else value;
    } else null;

    return frame.ZStandard.Header{
        .descriptor = descriptor,
        .window_descriptor = window_descriptor,
        .dictionary_id = dictionary_id,
        .content_size = content_size,
    };
}
/// Decode the 3-byte header of a block: bit 0 is the last-block flag, bits
/// 1-2 the block type, and the remaining 21 bits (little-endian across the
/// three bytes) the block size.
pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
    const first = src[0];
    return .{
        .last_block = first & 1 == 1,
        .block_type = @intToEnum(frame.ZStandard.Block.Type, (first & 0b110) >> 1),
        .block_size = ((first & 0b11111000) >> 3) + (@as(u21, src[1]) << 5) + (@as(u21, src[2]) << 13),
    };
}
/// Decode a block header from the start of `src`, failing with
/// `error.EndOfStream` if fewer than 3 bytes are available.
pub fn decodeBlockHeaderSlice(src: []const u8) error{EndOfStream}!frame.ZStandard.Block.Header {
    return if (src.len < 3) error.EndOfStream else decodeBlockHeader(src[0..3]);
}
/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the
/// number of bytes the section uses.
///
/// Errors:
///   - returns `error.MalformedLiteralsHeader` if the header is invalid
///   - returns `error.MalformedLiteralsSection` if there are errors decoding
///   - returns `error.EndOfStream` if there are not enough bytes in `src`
pub fn decodeLiteralsSectionSlice(
    src: []const u8,
    consumed_count: *usize,
) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || DecodeHuffmanError)!LiteralsSection {
    var bytes_read: usize = 0;
    const header = header: {
        // Decode the variable-length header, recording how many bytes it
        // occupied via the stream position (captured on scope exit).
        var fbs = std.io.fixedBufferStream(src);
        defer bytes_read = fbs.pos;
        break :header decodeLiteralsHeader(fbs.reader()) catch return error.MalformedLiteralsHeader;
    };
    switch (header.block_type) {
        .raw => {
            // Literal bytes are stored verbatim after the header.
            if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection;
            const stream = src[bytes_read .. bytes_read + header.regenerated_size];
            consumed_count.* += header.regenerated_size + bytes_read;
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = stream },
            };
        },
        .rle => {
            // A single byte to be repeated `regenerated_size` times.
            if (src.len < bytes_read + 1) return error.MalformedLiteralsSection;
            const stream = src[bytes_read .. bytes_read + 1];
            consumed_count.* += 1 + bytes_read;
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = stream },
            };
        },
        .compressed, .treeless => {
            // A `.compressed` section carries its own Huffman tree
            // description; `.treeless` reuses the previous tree
            // (huffman_tree stays null here).
            const huffman_tree_start = bytes_read;
            const huffman_tree = if (header.block_type == .compressed)
                try decodeHuffmanTreeSlice(src[bytes_read..], &bytes_read)
            else
                null;
            const huffman_tree_size = bytes_read - huffman_tree_start;
            // compressed_size is always set for compressed/treeless sections
            // (see decodeLiteralsHeader); it covers tree plus stream bytes.
            const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size;

            if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection;
            const stream_data = src[bytes_read .. bytes_read + total_streams_size];

            const streams = try decodeStreams(header.size_format, stream_data);
            consumed_count.* += bytes_read + total_streams_size;
            return LiteralsSection{
                .header = header,
                .huffman_tree = huffman_tree,
                .streams = streams,
            };
        },
    }
}
/// Decode a `LiteralsSection` from `source`, copying literal/stream data into
/// `buffer`. The returned section's streams point into `buffer`.
///
/// Errors:
///   - returns `error.MalformedLiteralsHeader` if the header is invalid
///   - returns `error.MalformedLiteralsSection` if there are errors decoding
///   - returns `error.LiteralsBufferTooSmall` if `buffer` cannot hold the
///     section's data
pub fn decodeLiteralsSection(
    source: anytype,
    buffer: []u8,
) !LiteralsSection {
    const header = try decodeLiteralsHeader(source);
    switch (header.block_type) {
        .raw => {
            // Literals stored verbatim; validate the destination before
            // reading so oversized sections fail cleanly instead of tripping
            // a slice bounds check.
            if (buffer.len < header.regenerated_size) return error.LiteralsBufferTooSmall;
            try source.readNoEof(buffer[0..header.regenerated_size]);
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                // Only the bytes actually read belong to the stream.
                .streams = .{ .one = buffer[0..header.regenerated_size] },
            };
        },
        .rle => {
            // A single byte to be repeated `regenerated_size` times.
            if (buffer.len < 1) return error.LiteralsBufferTooSmall;
            buffer[0] = try source.readByte();
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = buffer[0..1] },
            };
        },
        .compressed, .treeless => {
            // A `.compressed` section carries a Huffman tree description
            // before the streams; `.treeless` reuses the previous tree.
            var counting_reader = std.io.countingReader(source);
            const huffman_tree = if (header.block_type == .compressed)
                try decodeHuffmanTree(counting_reader.reader(), buffer)
            else
                null;
            const huffman_tree_size = counting_reader.bytes_read;
            // compressed_size is always set for these block types (see
            // decodeLiteralsHeader); it covers tree plus stream bytes.
            const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size);
            if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
            try source.readNoEof(buffer[0..total_streams_size]);
            const stream_data = buffer[0..total_streams_size];

            const streams = try decodeStreams(header.size_format, stream_data);
            return LiteralsSection{
                .header = header,
                .huffman_tree = huffman_tree,
                .streams = streams,
            };
        },
    }
}
/// Split `stream_data` into the literal streams selected by `size_format`.
/// A size format of 0 means a single stream; otherwise the data begins with a
/// 6-byte jump table giving the sizes of the first three of four streams.
///
/// Errors:
///   - returns `error.MalformedLiteralsSection` if `stream_data` is too small
///     for the jump table or for the stream sizes it declares
fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams {
    if (size_format == 0) {
        return .{ .one = stream_data };
    }

    if (stream_data.len < 6) return error.MalformedLiteralsSection;

    const stream_1_length = @as(usize, readInt(u16, stream_data[0..2]));
    const stream_2_length = @as(usize, readInt(u16, stream_data[2..4]));
    const stream_3_length = @as(usize, readInt(u16, stream_data[4..6]));

    const stream_1_start = 6;
    const stream_2_start = stream_1_start + stream_1_length;
    const stream_3_start = stream_2_start + stream_2_length;
    const stream_4_start = stream_3_start + stream_3_length;

    // The declared sizes must fit inside the data we actually have;
    // otherwise the slicing below would go out of bounds on malformed input
    // instead of reporting a decode error.
    if (stream_data.len < stream_4_start) return error.MalformedLiteralsSection;

    return .{ .four = .{
        stream_data[stream_1_start .. stream_1_start + stream_1_length],
        stream_data[stream_2_start .. stream_2_start + stream_2_length],
        stream_data[stream_3_start .. stream_3_start + stream_3_length],
        stream_data[stream_4_start..],
    } };
}
/// Errors that can occur while decoding the Huffman tree description of a
/// literals section.
const DecodeHuffmanError = error{
    MalformedHuffmanTree,
    MalformedFseTable,
    MalformedAccuracyLog,
};
/// Decode an FSE-compressed Huffman weight table from at most
/// `compressed_size` bytes of `source`. `buffer` receives the weight bit
/// stream bytes; decoded weights go into `weights`. Returns the number of
/// symbols with weights (from `assignWeights`).
fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize {
    var stream = std.io.limitedReader(source, compressed_size);
    var bit_reader = bitReader(stream.reader());

    var entries: [1 << 6]Table.Fse = undefined;
    // The weight table's FSE description uses a maximum accuracy log of 6.
    const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        error.EndOfStream => return error.MalformedFseTable,
    };
    // table_size is a power of two, so this recovers the accuracy log exactly.
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    // The bytes after the table description form a reversed bit stream of
    // weights, decoded via two interleaved FSE states.
    const amount = try stream.reader().readAll(buffer);
    var huff_bits: ReverseBitReader = undefined;
    huff_bits.init(buffer[0..amount]) catch return error.MalformedHuffmanTree;

    return assignWeights(&huff_bits, accuracy_log, &entries, weights);
}
/// Decode an FSE-compressed Huffman weight table from the first
/// `compressed_size` bytes of `src`, writing decoded weights to `weights`.
/// Slice counterpart of `decodeFseHuffmanTree`.
fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize {
    if (src.len < compressed_size) return error.MalformedHuffmanTree;
    var stream = std.io.fixedBufferStream(src[0..compressed_size]);
    // Count bytes consumed by the table description so the weight bit stream
    // can be located afterwards.
    var counting_reader = std.io.countingReader(stream.reader());
    var bit_reader = bitReader(counting_reader.reader());

    var entries: [1 << 6]Table.Fse = undefined;
    // The weight table's FSE description uses a maximum accuracy log of 6.
    const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        error.EndOfStream => return error.MalformedFseTable,
    };
    // table_size is a power of two, so this recovers the accuracy log exactly.
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    // The remaining bytes form a reversed bit stream of weights.
    const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
    var huff_data = src[start_index..compressed_size];
    var huff_bits: ReverseBitReader = undefined;
    huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;

    return assignWeights(&huff_bits, accuracy_log, &entries, weights);
}
/// Read Huffman weights from the reversed FSE bit stream `huff_bits` using
/// two interleaved FSE states, writing them into `weights`. Returns the
/// number of weighted symbols, including the final implicit one.
fn assignWeights(huff_bits: *ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize {
    var i: usize = 0;
    // The stream begins with the initial value of each of the two states.
    var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
    var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;

    while (i < 255) {
        const even_data = entries[even_state];
        var read_bits: usize = 0;
        // `ReverseBitReader.readBits` has an empty error set (it reports a
        // short count instead of failing), so the catch can never trigger.
        const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
        weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        // A short read means the stream is exhausted: the other state emits
        // one last weight and decoding stops.
        if (read_bits < even_data.bits) {
            weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        even_state = even_data.baseline + even_bits;

        read_bits = 0;
        const odd_data = entries[odd_state];
        const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
        weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        if (read_bits < odd_data.bits) {
            if (i == 256) return error.MalformedHuffmanTree;
            weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        odd_state = odd_data.baseline + odd_bits;
    } else return error.MalformedHuffmanTree; // emitting 255 weights without the stream ending is malformed

    return i + 1; // stream contains all but the last symbol
}
/// Read directly-encoded Huffman weights: each byte packs two 4-bit weights,
/// high nibble first. Returns the total symbol count, which includes the one
/// implicit trailing symbol.
fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize {
    const byte_count = (encoded_symbol_count + 1) / 2;
    var index: usize = 0;
    while (index < byte_count) : (index += 1) {
        const pair = try source.readByte();
        weights[2 * index] = @intCast(u4, pair >> 4);
        weights[2 * index + 1] = @intCast(u4, pair & 0xF);
    }
    return encoded_symbol_count + 1;
}
/// Sort symbols by weight (stable, so equal weights keep symbol order) and
/// assign each non-zero-weight symbol a prefix code, compacting them to the
/// front of the array. Returns the number of symbols that received a prefix.
fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize {
    // Seed one entry per symbol; weight and prefix are filled in below.
    for (weight_sorted_prefixed_symbols) |_, i| {
        weight_sorted_prefixed_symbols[i] = .{
            .symbol = @intCast(u8, i),
            .weight = undefined,
            .prefix = undefined,
        };
    }

    std.sort.sort(
        LiteralsSection.HuffmanTree.PrefixedSymbol,
        weight_sorted_prefixed_symbols,
        weights,
        lessThanByWeight,
    );

    var prefix: u16 = 0;
    var prefixed_symbol_count: usize = 0;
    var sorted_index: usize = 0;
    const symbol_count = weight_sorted_prefixed_symbols.len;
    while (sorted_index < symbol_count) {
        var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
        const weight = weights[symbol];
        // Zero-weight symbols are not part of the tree and get no prefix.
        if (weight == 0) {
            sorted_index += 1;
            continue;
        }

        // Assign consecutive prefixes to the run of symbols sharing this
        // weight; when a greater weight is reached (the array is sorted
        // ascending), rescale the running prefix for the new weight class.
        while (sorted_index < symbol_count) : ({
            sorted_index += 1;
            prefixed_symbol_count += 1;
            prefix += 1;
        }) {
            symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
            if (weights[symbol] != weight) {
                prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1;
                break;
            }
            weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol;
            weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix;
            weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight;
        }
    }
    return prefixed_symbol_count;
}
/// Build a Huffman tree from the explicit weights in
/// `weights[0 .. symbol_count - 1]`. The final symbol's weight is implicit
/// and is computed here so that the weight powers sum to a power of two.
fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree {
    // Each non-zero weight w contributes 2^(w - 1) to the power sum.
    var weight_power_sum: u16 = 0;
    for (weights[0 .. symbol_count - 1]) |value| {
        if (value > 0) {
            weight_power_sum += @as(u16, 1) << (value - 1);
        }
    }

    // advance to next power of two (even if weight_power_sum is a power of 2)
    const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
    const next_power_of_two = @as(u16, 1) << max_number_of_bits;
    // The last symbol's weight is whatever closes the gap to that power.
    weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1;

    var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
    const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*);
    const tree = LiteralsSection.HuffmanTree{
        .max_bit_count = max_number_of_bits,
        .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1),
        .nodes = weight_sorted_prefixed_symbols,
    };
    return tree;
}
/// Decode a Huffman tree description from `source`. A header byte below 128
/// is the byte size of an FSE-compressed weight table (decoded via `buffer`);
/// otherwise `header - 127` weights follow directly, packed two per byte.
fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree {
    const header_byte = try source.readByte();
    var weights: [256]u4 = undefined;
    const symbol_count = if (header_byte >= 128)
        try decodeDirectHuffmanTree(source, header_byte - 127, &weights)
    else
        try decodeFseHuffmanTree(source, header_byte, buffer, &weights);
    return buildHuffmanTree(&weights, symbol_count);
}
/// Decode a Huffman tree description from `src`, incrementing
/// `consumed_count` by the number of bytes read. Slice counterpart of
/// `decodeHuffmanTree`.
fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) (error{EndOfStream} || DecodeHuffmanError)!LiteralsSection.HuffmanTree {
    if (src.len == 0) return error.MalformedHuffmanTree;
    const header = src[0];
    var bytes_read: usize = 1;
    var weights: [256]u4 = undefined;
    const symbol_count = if (header < 128) count: {
        // FSE compressed weights: the header byte is the exact size of the
        // compressed table, so account for it up front.
        bytes_read += header;
        break :count try decodeFseHuffmanTreeSlice(src[1..], header, &weights);
    } else count: {
        // Direct weights: track the bytes consumed via the stream position
        // (captured on scope exit).
        var fbs = std.io.fixedBufferStream(src[1..]);
        defer bytes_read += fbs.pos;
        break :count try decodeDirectHuffmanTree(fbs.reader(), header - 127, &weights);
    };
    consumed_count.* += bytes_read;
    return buildHuffmanTree(&weights, symbol_count);
}
/// Ordering predicate for `assignSymbols`: compares two prefixed symbols by
/// their weight in `weights`.
fn lessThanByWeight(
    weights: [256]u4,
    lhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
    rhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
) bool {
    // NOTE: this function relies on the use of a stable sorting algorithm,
    // otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs;
    // should be added
    const lhs_weight = weights[lhs.symbol];
    const rhs_weight = weights[rhs.symbol];
    return lhs_weight < rhs_weight;
}
/// Decode a literals section header.
///
/// Errors:
///   - returns `error.EndOfStream` if `source` ends before the header does
pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header {
    const byte0 = try source.readByte();
    // Bits 0-1: block type; bits 2-3: size format; the remaining bits begin
    // the variable-width size fields decoded below.
    const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
    const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
    var regenerated_size: u20 = undefined;
    var compressed_size: ?u18 = null;
    switch (block_type) {
        .raw, .rle => {
            // Raw/RLE sections have no compressed size; the size format
            // selects a 5, 12, or 20 bit regenerated size (little-endian).
            switch (size_format) {
                0, 2 => {
                    regenerated_size = byte0 >> 3;
                },
                1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4),
                3 => regenerated_size = (byte0 >> 4) +
                    (@as(u20, try source.readByte()) << 4) +
                    (@as(u20, try source.readByte()) << 12),
            }
        },
        .compressed, .treeless => {
            // Compressed sections pack regenerated and compressed sizes of
            // 10/10, 14/14, or 18/18 bits depending on the size format.
            const byte1 = try source.readByte();
            const byte2 = try source.readByte();
            switch (size_format) {
                0, 1 => {
                    regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4);
                    compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2);
                },
                2 => {
                    const byte3 = try source.readByte();
                    regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12);
                    compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6);
                },
                3 => {
                    const byte3 = try source.readByte();
                    const byte4 = try source.readByte();
                    regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12);
                    compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10);
                },
            }
        },
    }
    return LiteralsSection.Header{
        .block_type = block_type,
        .size_format = size_format,
        .regenerated_size = regenerated_size,
        .compressed_size = compressed_size,
    };
}
/// Decode a sequences section header.
///
/// Errors:
///   - returns `error.ReservedBitSet` if the reserved bits of the compression
///     modes byte are set
///   - returns `error.EndOfStream` if `source` ends before the header does
pub fn decodeSequencesHeader(
    source: anytype,
) !SequencesSection.Header {
    var sequence_count: u24 = undefined;

    const byte0 = try source.readByte();
    // Sequence count encoding: 0 means no sequences (and no modes byte
    // follows); 1-127 is the count itself; 128-254 starts a two-byte count;
    // 255 starts a three-byte count offset by 0x7F00.
    if (byte0 == 0) {
        return SequencesSection.Header{
            .sequence_count = 0,
            .offsets = undefined,
            .match_lengths = undefined,
            .literal_lengths = undefined,
        };
    } else if (byte0 < 128) {
        sequence_count = byte0;
    } else if (byte0 < 255) {
        sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte();
    } else {
        sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00;
    }

    // Two bits per table (literal lengths, offsets, match lengths); the low
    // two bits are reserved and must be zero.
    const compression_modes = try source.readByte();

    const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2);
    const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4);
    const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6);
    if (compression_modes & 0b11 != 0) return error.ReservedBitSet;

    return SequencesSection.Header{
        .sequence_count = sequence_count,
        .offsets = offsets_mode,
        .match_lengths = matches_mode,
        .literal_lengths = literal_mode,
    };
}
/// Build an FSE decoding table from the symbol values in `values` (each
/// stored as probability + 1, with 0 meaning a "less than one" probability).
/// `entries.len` must be the table size: a power of two no larger than 2^9.
fn buildFseTable(values: []const u16, entries: []Table.Fse) !void {
    const total_probability = @intCast(u16, entries.len);
    const accuracy_log = std.math.log2_int(u16, total_probability);
    assert(total_probability <= 1 << 9);

    // "Less than one" probability symbols each claim one slot from the end
    // of the table and always consume the full accuracy_log bits.
    var less_than_one_count: usize = 0;
    for (values) |value, i| {
        if (value == 0) {
            entries[entries.len - 1 - less_than_one_count] = Table.Fse{
                .symbol = @intCast(u8, i),
                .baseline = 0,
                .bits = accuracy_log,
            };
            less_than_one_count += 1;
        }
    }

    var position: usize = 0;
    var temp_states: [1 << 9]u16 = undefined;
    for (values) |value, symbol| {
        if (value == 0 or value == 1) continue;
        const probability = value - 1;

        // Split this symbol's states: `double_state_count` of them read one
        // extra bit (share_size_log + 1) and the rest read share_size_log.
        const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch
            return error.MalformedFseTable;
        const share_size = @divExact(total_probability, state_share_dividend);
        const double_state_count = state_share_dividend - probability;
        const single_state_count = probability - double_state_count;
        const share_size_log = std.math.log2_int(u16, share_size);

        // Spread this symbol's states through the table with the step
        // (size/2 + size/8 + 3) mod size, skipping the slots reserved at the
        // end for "less than one" symbols.
        var i: u16 = 0;
        while (i < probability) : (i += 1) {
            temp_states[i] = @intCast(u16, position);
            position += (entries.len >> 1) + (entries.len >> 3) + 3;
            position &= entries.len - 1;
            while (position >= entries.len - less_than_one_count) {
                position += (entries.len >> 1) + (entries.len >> 3) + 3;
                position &= entries.len - 1;
            }
        }

        // States are numbered in ascending slot order; the extra-bit states
        // come first and carry the higher baselines.
        std.sort.sort(u16, temp_states[0..probability], {}, std.sort.asc(u16));
        i = 0;
        while (i < probability) : (i += 1) {
            entries[temp_states[i]] = if (i < double_state_count) Table.Fse{
                .symbol = @intCast(u8, symbol),
                .bits = share_size_log + 1,
                .baseline = single_state_count * share_size + i * 2 * share_size,
            } else Table.Fse{
                .symbol = @intCast(u8, symbol),
                .bits = share_size_log,
                .baseline = (i - double_state_count) * share_size,
            };
        }
    }
}
/// Decode an FSE table description from `bit_reader` into `entries`,
/// returning the table size (1 << accuracy_log). The description may cover
/// at most `expected_symbol_count` symbols and use an accuracy log of at
/// most `max_accuracy_log`.
fn decodeFseTable(
    bit_reader: anytype,
    expected_symbol_count: usize,
    max_accuracy_log: u4,
    entries: []Table.Fse,
) !usize {
    // The accuracy log is stored biased by 5.
    const accuracy_log_biased = try bit_reader.readBitsNoEof(u4, 4);
    if (accuracy_log_biased > max_accuracy_log -| 5) return error.MalformedAccuracyLog;
    const accuracy_log = accuracy_log_biased + 5;

    var values: [256]u16 = undefined;
    var value_count: usize = 0;

    const total_probability = @as(u16, 1) << accuracy_log;
    var accumulated_probability: u16 = 0;

    while (accumulated_probability < total_probability) {
        // WARNING: The RFC is poorly worded, and would suggest std.math.log2_int_ceil is correct here,
        // but power of two (remaining probabilities + 1) need max bits set to 1 more.
        const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1;
        const small = try bit_reader.readBitsNoEof(u16, max_bits - 1);

        // Values below the cutoff fit in max_bits - 1 bits; larger values
        // need one more bit to disambiguate the folded range.
        const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1);

        const value = if (small < cutoff)
            small
        else value: {
            const value_read = small + (try bit_reader.readBitsNoEof(u16, 1) << (max_bits - 1));
            break :value if (value_read < @as(u16, 1) << (max_bits - 1))
                value_read
            else
                value_read - cutoff;
        };

        // Values are stored as probability + 1; value 0 denotes a "less than
        // one" probability which still consumes one unit of the budget.
        accumulated_probability += if (value != 0) value - 1 else 1;

        values[value_count] = value;
        value_count += 1;

        // A zero probability (value == 1) may be followed by 2-bit repeat
        // flags adding more zero-probability symbols; a flag of 3 chains on.
        if (value == 1) {
            while (true) {
                const repeat_flag = try bit_reader.readBitsNoEof(u2, 2);
                var i: usize = 0;
                while (i < repeat_flag) : (i += 1) {
                    values[value_count] = 1;
                    value_count += 1;
                }
                if (repeat_flag < 3) break;
            }
        }
    }
    bit_reader.alignToByte();

    if (value_count < 2) return error.MalformedFseTable;
    if (accumulated_probability != total_probability) return error.MalformedFseTable;
    if (value_count > expected_symbol_count) return error.MalformedFseTable;

    const table_size = total_probability;

    try buildFseTable(values[0..value_count], entries[0..table_size]);
    return table_size;
}
/// A reader that yields the bytes of a slice in reverse order, used as the
/// byte source for `ReverseBitReader`.
const ReversedByteReader = struct {
    remaining_bytes: usize,
    bytes: []const u8,

    const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn);

    fn init(bytes: []const u8) ReversedByteReader {
        return .{
            .bytes = bytes,
            .remaining_bytes = bytes.len,
        };
    }

    fn reader(self: *ReversedByteReader) Reader {
        return .{ .context = self };
    }

    /// Produces exactly one byte per call, walking backwards from the end of
    /// the slice; returns 0 once the front has been reached.
    fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize {
        if (ctx.remaining_bytes == 0) return 0;
        ctx.remaining_bytes -= 1;
        buffer[0] = ctx.bytes[ctx.remaining_bytes];
        return 1;
    }
};
/// A bit reader for reading the reversed bit streams used to encode
/// FSE compressed data.
pub const ReverseBitReader = struct {
    byte_reader: ReversedByteReader,
    bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),

    /// Initialize the reader over `bytes`, consuming zero bits until the
    /// first set bit (the stream's start marker) is found. Fails if `bytes`
    /// contains no set bit at all.
    pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
        self.byte_reader = ReversedByteReader.init(bytes);
        self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
        while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
    }

    /// Read exactly `num_bits` bits, failing if fewer remain in the stream.
    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
        return self.bit_reader.readBitsNoEof(U, num_bits);
    }

    /// Read up to `num_bits` bits; `out_bits` receives the number actually
    /// read (short when the stream ends). Never fails.
    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
        return try self.bit_reader.readBits(U, num_bits, out_bits);
    }

    pub fn alignToByte(self: *@This()) void {
        self.bit_reader.alignToByte();
    }
};
/// A little-endian bit reader over `ReaderType`, used for the forward bit
/// streams of FSE table descriptions.
fn BitReader(comptime ReaderType: type) type {
    return struct {
        underlying: std.io.BitReader(.Little, ReaderType),

        /// Read exactly `num_bits` bits, failing if the stream ends early.
        fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
            return self.underlying.readBitsNoEof(U, num_bits);
        }

        /// Read up to `num_bits` bits; `out_bits` receives the count read.
        fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
            return self.underlying.readBits(U, num_bits, out_bits);
        }

        /// Discard any partially-consumed byte.
        fn alignToByte(self: *@This()) void {
            self.underlying.alignToByte();
        }
    };
}
/// Construct a little-endian `BitReader` wrapping `reader`.
pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
    return BitReader(@TypeOf(reader)){ .underlying = std.io.bitReader(.Little, reader) };
}
test {
    // Reference all declarations so tests nested inside them are run.
    std.testing.refAllDecls(@This());
}
test buildFseTable {
    // Default probability distributions for the three sequence code types;
    // building them must reproduce the predefined decoding tables declared
    // in `types.zig`.
    const literals_length_default_values = [36]u16{
        5, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 2, 2, 2, 2,
        0, 0, 0, 0,
    };

    const match_lengths_default_values = [53]u16{
        2, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0,
        0, 0, 0, 0, 0,
    };

    const offset_codes_default_values = [29]u16{
        2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0,
    };

    var entries: [64]Table.Fse = undefined;
    try buildFseTable(&literals_length_default_values, &entries);
    try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_literal_fse_table.fse, &entries);

    try buildFseTable(&match_lengths_default_values, &entries);
    try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_match_fse_table.fse, &entries);

    try buildFseTable(&offset_codes_default_values, entries[0..32]);
    try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_offset_fse_table.fse, entries[0..32]);
}