std.compress.zstandard: split decompressor into multiple files

This commit is contained in:
dweiller 2023-02-02 18:44:01 +11:00
parent 6e3e72884b
commit 7e2755646f
6 changed files with 1538 additions and 1466 deletions

View File

@ -13,7 +13,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
allocator: Allocator,
in_reader: ReaderType,
decode_state: decompress.DecodeState,
decode_state: decompress.block.DecodeState,
frame_context: decompress.FrameContext,
buffer: RingBuffer,
last_block: bool,
@ -24,7 +24,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
sequence_buffer: []u8,
checksum: if (verify_checksum) ?u32 else void,
pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame, EndOfStream };
pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame };
pub const Reader = std.io.Reader(*Self, Error, read);
@ -34,21 +34,41 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
.zstandard => {
const frame_context = context: {
const frame_header = try decompress.decodeZStandardHeader(source);
break :context try decompress.FrameContext.init(frame_header, window_size_max, verify_checksum);
break :context try decompress.FrameContext.init(
frame_header,
window_size_max,
verify_checksum,
);
};
const literal_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.literal);
const literal_fse_buffer = try allocator.alloc(
types.compressed_block.Table.Fse,
types.compressed_block.table_size_max.literal,
);
errdefer allocator.free(literal_fse_buffer);
const match_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.match);
const match_fse_buffer = try allocator.alloc(
types.compressed_block.Table.Fse,
types.compressed_block.table_size_max.match,
);
errdefer allocator.free(match_fse_buffer);
const offset_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.offset);
const offset_fse_buffer = try allocator.alloc(
types.compressed_block.Table.Fse,
types.compressed_block.table_size_max.offset,
);
errdefer allocator.free(offset_fse_buffer);
const decode_state = decompress.DecodeState.init(literal_fse_buffer, match_fse_buffer, offset_fse_buffer);
const decode_state = decompress.block.DecodeState.init(
literal_fse_buffer,
match_fse_buffer,
offset_fse_buffer,
);
const buffer = try RingBuffer.init(allocator, frame_context.window_size);
const literals_data = try allocator.alloc(u8, window_size_max);
errdefer allocator.free(literals_data);
const sequence_data = try allocator.alloc(u8, window_size_max);
errdefer allocator.free(sequence_data);
@ -87,10 +107,10 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
if (buffer.len == 0) return 0;
if (self.buffer.isEmpty() and !self.last_block) {
const header_bytes = try self.in_reader.readBytesNoEof(3);
const block_header = decompress.decodeBlockHeader(&header_bytes);
const header_bytes = self.in_reader.readBytesNoEof(3) catch return error.MalformedFrame;
const block_header = decompress.block.decodeBlockHeader(&header_bytes);
decompress.decodeBlockReader(
decompress.block.decodeBlockReader(
&self.buffer,
self.in_reader,
block_header,

File diff suppressed because it is too large. Load Diff

View File

@ -0,0 +1,154 @@
const std = @import("std");
const assert = std.debug.assert;
const types = @import("../types.zig");
const Table = types.compressed_block.Table;
/// Decode an FSE table description from `bit_reader` and build the decoding
/// table into `entries` (RFC 8878 section 4.1.1, "FSE Table Description").
///
/// `expected_symbol_count` is the largest alphabet size the caller accepts;
/// `max_accuracy_log` bounds the table size. Returns the number of entries
/// written, i.e. `1 << accuracy_log`.
///
/// Errors: `MalformedAccuracyLog` when the stored accuracy log exceeds the
/// allowed maximum, `MalformedFseTable` for inconsistent probability data,
/// plus any error from `bit_reader` (e.g. `EndOfStream`).
pub fn decodeFseTable(
    bit_reader: anytype,
    expected_symbol_count: usize,
    max_accuracy_log: u4,
    entries: []Table.Fse,
) !usize {
    // The accuracy log is stored biased by its minimum legal value, 5.
    const accuracy_log_biased = try bit_reader.readBitsNoEof(u4, 4);
    if (accuracy_log_biased > max_accuracy_log -| 5) return error.MalformedAccuracyLog;
    const accuracy_log = accuracy_log_biased + 5;

    var values: [256]u16 = undefined;
    var value_count: usize = 0;

    const total_probability = @as(u16, 1) << accuracy_log;
    var accumulated_probability: u16 = 0;

    while (accumulated_probability < total_probability) {
        // WARNING: The RFC is poorly worded, and would suggest std.math.log2_int_ceil is correct here,
        // but power of two (remaining probabilities + 1) need max bits set to 1 more.
        const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1;
        const small = try bit_reader.readBitsNoEof(u16, max_bits - 1);

        const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1);

        // Values below the cutoff fit in max_bits - 1 bits; otherwise read
        // one extra bit and un-bias the wide encoding.
        const value = if (small < cutoff)
            small
        else value: {
            const value_read = small + (try bit_reader.readBitsNoEof(u16, 1) << (max_bits - 1));
            break :value if (value_read < @as(u16, 1) << (max_bits - 1))
                value_read
            else
                value_read - cutoff;
        };

        // Stored values are biased by one: 0 encodes probability -1
        // ("less than one"), v > 0 encodes probability v - 1.
        accumulated_probability += if (value != 0) value - 1 else 1;

        // FIX: bound value_count before each write so a malicious stream
        // cannot index past the 256-entry scratch array (previously this
        // tripped a safety check / out-of-bounds write instead of returning
        // a decode error).
        if (value_count == values.len) return error.MalformedFseTable;
        values[value_count] = value;
        value_count += 1;

        if (value == 1) {
            // A zero-probability symbol is followed by 2-bit repeat flags;
            // each flag appends that many additional zero-probability
            // symbols, and a flag of 3 continues the run.
            while (true) {
                const repeat_flag = try bit_reader.readBitsNoEof(u2, 2);
                if (repeat_flag + value_count > values.len) return error.MalformedFseTable;
                var i: usize = 0;
                while (i < repeat_flag) : (i += 1) {
                    values[value_count] = 1;
                    value_count += 1;
                }
                if (repeat_flag < 3) break;
            }
        }
    }
    // The probability data is padded to a byte boundary.
    bit_reader.alignToByte();

    if (value_count < 2) return error.MalformedFseTable;
    if (accumulated_probability != total_probability) return error.MalformedFseTable;
    if (value_count > expected_symbol_count) return error.MalformedFseTable;

    const table_size = total_probability;

    try buildFseTable(values[0..value_count], entries[0..table_size]);
    return table_size;
}
/// Spread symbol states across `entries` to build an FSE decoding table from
/// the decoded probability `values` (still biased by one: 0 = "less than
/// one", v > 0 = probability v - 1). `entries.len` must be the table size
/// (a power of two, at most 1 << 9).
fn buildFseTable(values: []const u16, entries: []Table.Fse) !void {
    const total_probability = @intCast(u16, entries.len);
    const accuracy_log = std.math.log2_int(u16, total_probability);
    assert(total_probability <= 1 << 9);

    // "Less than one" symbols each take a single state at the end of the
    // table, with a full accuracy_log bit read and baseline 0.
    var less_than_one_count: usize = 0;
    for (values) |value, i| {
        if (value == 0) {
            entries[entries.len - 1 - less_than_one_count] = Table.Fse{
                .symbol = @intCast(u8, i),
                .baseline = 0,
                .bits = accuracy_log,
            };
            less_than_one_count += 1;
        }
    }

    var position: usize = 0;
    var temp_states: [1 << 9]u16 = undefined;
    for (values) |value, symbol| {
        // Skip zero-probability symbols: value 0 was handled above and
        // value 1 encodes probability 0 (no states at all).
        if (value == 0 or value == 1) continue;
        const probability = value - 1;

        // Each of this symbol's states covers share_size table positions;
        // double_state_count of them cover twice that (one extra bit read).
        const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch
            return error.MalformedFseTable;
        const share_size = @divExact(total_probability, state_share_dividend);
        const double_state_count = state_share_dividend - probability;
        const single_state_count = probability - double_state_count;
        const share_size_log = std.math.log2_int(u16, share_size);

        // Scatter this symbol's states over the table using the standard
        // zstd step (tablesize/2 + tablesize/8 + 3) mod tablesize, skipping
        // positions reserved for "less than one" symbols at the end.
        var i: u16 = 0;
        while (i < probability) : (i += 1) {
            temp_states[i] = @intCast(u16, position);
            position += (entries.len >> 1) + (entries.len >> 3) + 3;
            position &= entries.len - 1;
            while (position >= entries.len - less_than_one_count) {
                position += (entries.len >> 1) + (entries.len >> 3) + 3;
                position &= entries.len - 1;
            }
        }
        // States are assigned baselines in ascending table order; the first
        // double_state_count states read one extra bit.
        std.sort.sort(u16, temp_states[0..probability], {}, std.sort.asc(u16));
        i = 0;
        while (i < probability) : (i += 1) {
            entries[temp_states[i]] = if (i < double_state_count) Table.Fse{
                .symbol = @intCast(u8, symbol),
                .bits = share_size_log + 1,
                .baseline = single_state_count * share_size + i * 2 * share_size,
            } else Table.Fse{
                .symbol = @intCast(u8, symbol),
                .bits = share_size_log,
                .baseline = (i - double_state_count) * share_size,
            };
        }
    }
}
test buildFseTable {
    // Building tables from the RFC 8878 predefined distributions must
    // reproduce the predefined decoding tables exactly.
    const expectEqualSlices = std.testing.expectEqualSlices;

    const literals_length_default_values = [36]u16{
        5, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 2, 2, 2, 2,
        0, 0, 0, 0,
    };

    const match_lengths_default_values = [53]u16{
        2, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0,
        0, 0, 0, 0, 0,
    };

    const offset_codes_default_values = [29]u16{
        2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0,
    };

    var table_buffer: [64]Table.Fse = undefined;

    try buildFseTable(&literals_length_default_values, &table_buffer);
    try expectEqualSlices(Table.Fse, types.compressed_block.predefined_literal_fse_table.fse, &table_buffer);

    try buildFseTable(&match_lengths_default_values, &table_buffer);
    try expectEqualSlices(Table.Fse, types.compressed_block.predefined_match_fse_table.fse, &table_buffer);

    try buildFseTable(&offset_codes_default_values, table_buffer[0..32]);
    try expectEqualSlices(Table.Fse, types.compressed_block.predefined_offset_fse_table.fse, table_buffer[0..32]);
}

View File

@ -0,0 +1,212 @@
const std = @import("std");
const types = @import("../types.zig");
const LiteralsSection = types.compressed_block.LiteralsSection;
const Table = types.compressed_block.Table;
const readers = @import("../readers.zig");
const decodeFseTable = @import("fse.zig").decodeFseTable;
/// Errors that can occur while decoding a Huffman tree description.
pub const Error = error{
    MalformedHuffmanTree,
    MalformedFseTable,
    MalformedAccuracyLog,
    EndOfStream,
};
/// Decode FSE-compressed Huffman weights from `source`. The first
/// `compressed_size` bytes hold an FSE table followed by the reversed bit
/// stream of weights; `buffer` is scratch space for that stream.
/// Returns the number of weights decoded (written into `weights`).
fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize {
    var limited = std.io.limitedReader(source, compressed_size);
    var fse_bits = readers.bitReader(limited.reader());

    var fse_entries: [1 << 6]Table.Fse = undefined;
    const table_size = decodeFseTable(&fse_bits, 256, 6, &fse_entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        error.EndOfStream => return error.MalformedFseTable,
    };
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    // The remainder of the compressed payload is the reversed weight stream.
    const stream_len = try limited.reader().readAll(buffer);
    var weight_stream: readers.ReverseBitReader = undefined;
    weight_stream.init(buffer[0..stream_len]) catch return error.MalformedHuffmanTree;

    return assignWeights(&weight_stream, accuracy_log, &fse_entries, weights);
}
/// Slice-based variant of `decodeFseHuffmanTree`: decode FSE-compressed
/// Huffman weights from the first `compressed_size` bytes of `src`.
/// Returns the number of weights decoded (written into `weights`).
fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize {
    if (src.len < compressed_size) return error.MalformedHuffmanTree;

    var fbs = std.io.fixedBufferStream(src[0..compressed_size]);
    var counter = std.io.countingReader(fbs.reader());
    var fse_bits = readers.bitReader(counter.reader());

    var fse_entries: [1 << 6]Table.Fse = undefined;
    const table_size = decodeFseTable(&fse_bits, 256, 6, &fse_entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        error.EndOfStream => return error.MalformedFseTable,
    };
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    // Everything after the FSE table (its length tracked by the counting
    // reader) is the reversed Huffman weight bit stream.
    const weights_offset = std.math.cast(usize, counter.bytes_read) orelse return error.MalformedHuffmanTree;
    var weight_stream: readers.ReverseBitReader = undefined;
    weight_stream.init(src[weights_offset..compressed_size]) catch return error.MalformedHuffmanTree;

    return assignWeights(&weight_stream, accuracy_log, &fse_entries, weights);
}
/// Read Huffman weights from the reversed FSE bit stream using two
/// interleaved decoder states (even and odd), per RFC 8878. Writes up to 255
/// weights into `weights` and returns the symbol count including the final
/// implicit weight (which the caller computes).
fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize {
    var i: usize = 0;
    // Both states are primed with accuracy_log bits each.
    var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
    var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;

    while (i < 255) {
        const even_data = entries[even_state];
        var read_bits: usize = 0;
        // ReverseBitReader.readBits has an empty error set, so this cannot
        // fail; a short read is reported via read_bits instead.
        const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
        weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        if (read_bits < even_data.bits) {
            // Stream exhausted mid-update: the odd state's symbol is the
            // last explicit weight.
            weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        even_state = even_data.baseline + even_bits;

        read_bits = 0;
        const odd_data = entries[odd_state];
        const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
        weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        if (read_bits < odd_data.bits) {
            // Symmetric case: even state's symbol closes the stream. Here i
            // can already be 256, so guard the final write.
            if (i == 256) return error.MalformedHuffmanTree;
            weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        odd_state = odd_data.baseline + odd_bits;
    } else return error.MalformedHuffmanTree;

    return i + 1; // stream contains all but the last symbol
}
/// Read direct-encoded Huffman weights from `source`: two 4-bit weights per
/// byte, high nibble first. Returns the total symbol count, i.e. the encoded
/// count plus the final implicit weight computed by the caller.
fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize {
    const packed_byte_count = (encoded_symbol_count + 1) / 2;
    var byte_index: usize = 0;
    while (byte_index < packed_byte_count) : (byte_index += 1) {
        const packed_weights = try source.readByte();
        weights[2 * byte_index] = @intCast(u4, packed_weights >> 4);
        weights[2 * byte_index + 1] = @intCast(u4, packed_weights & 0xF);
    }
    return encoded_symbol_count + 1;
}
/// Sort symbols by weight (stable, so equal weights stay in symbol order)
/// and assign each nonzero-weight symbol a prefix code. Compacts the
/// nonzero-weight symbols to the front of the slice and returns how many
/// there are.
fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize {
    // Seed the slice with symbol indices; weight/prefix are filled in below.
    for (weight_sorted_prefixed_symbols) |_, i| {
        weight_sorted_prefixed_symbols[i] = .{
            .symbol = @intCast(u8, i),
            .weight = undefined,
            .prefix = undefined,
        };
    }

    // NOTE: relies on std.sort.sort being stable (see lessThanByWeight).
    std.sort.sort(
        LiteralsSection.HuffmanTree.PrefixedSymbol,
        weight_sorted_prefixed_symbols,
        weights,
        lessThanByWeight,
    );

    var prefix: u16 = 0;
    var prefixed_symbol_count: usize = 0;
    var sorted_index: usize = 0;
    const symbol_count = weight_sorted_prefixed_symbols.len;
    while (sorted_index < symbol_count) {
        var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
        const weight = weights[symbol];
        // Zero-weight symbols get no code and are skipped (not compacted).
        if (weight == 0) {
            sorted_index += 1;
            continue;
        }

        // Assign consecutive prefixes within one weight class; on entering
        // the next (higher) weight class, rescale the running prefix.
        while (sorted_index < symbol_count) : ({
            sorted_index += 1;
            prefixed_symbol_count += 1;
            prefix += 1;
        }) {
            symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
            if (weights[symbol] != weight) {
                // assumes canonical-code prefix rescaling between weight
                // classes per RFC 8878 — TODO confirm against spec.
                prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1;
                break;
            }
            weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol;
            weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix;
            weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight;
        }
    }

    return prefixed_symbol_count;
}
/// Compute the final implicit weight, then build the Huffman tree from the
/// decoded weights. `symbol_count` includes the final symbol whose weight is
/// derived here (its weight slot is overwritten).
fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree {
    // Each weight w > 0 contributes 2^(w-1) to the power sum.
    var weight_power_sum: u16 = 0;
    for (weights[0 .. symbol_count - 1]) |value| {
        if (value > 0) {
            weight_power_sum += @as(u16, 1) << (value - 1);
        }
    }
    // NOTE(review): if all explicit weights are zero, weight_power_sum is 0
    // and log2_int would trip a safety check — TODO confirm callers/spec
    // exclude this for malformed input.

    // advance to next power of two (even if weight_power_sum is a power of 2)
    const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
    const next_power_of_two = @as(u16, 1) << max_number_of_bits;
    // The last weight is chosen so the power sum reaches the next power of
    // two exactly (the remainder is assumed to be a power of two for
    // well-formed streams).
    weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1;

    var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
    const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*);
    const tree = LiteralsSection.HuffmanTree{
        .max_bit_count = max_number_of_bits,
        .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1),
        .nodes = weight_sorted_prefixed_symbols,
    };
    return tree;
}
/// Decode a Huffman tree description from `source`. A header byte below 128
/// means FSE-compressed weights (the byte is the compressed size); 128 or
/// above means direct 4-bit weights (header - 127 of them). `buffer` is
/// scratch space for the FSE path.
pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree {
    const header_byte = try source.readByte();
    var weights: [256]u4 = undefined;
    const symbol_count = if (header_byte >= 128)
        try decodeDirectHuffmanTree(source, header_byte - 127, &weights)
    else
        try decodeFseHuffmanTree(source, header_byte, buffer, &weights);
    return buildHuffmanTree(&weights, symbol_count);
}
/// Slice-based variant of `decodeHuffmanTree`. Adds the number of bytes
/// consumed from `src` to `consumed_count` (only on success).
pub fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) Error!LiteralsSection.HuffmanTree {
    if (src.len == 0) return error.MalformedHuffmanTree;
    const header = src[0];
    var weights: [256]u4 = undefined;
    var bytes_read: usize = 1;
    var symbol_count: usize = undefined;
    if (header < 128) {
        // FSE compressed weights: the header byte is the compressed size.
        bytes_read += header;
        symbol_count = try decodeFseHuffmanTreeSlice(src[1..], header, &weights);
    } else {
        // Direct weights: header - 127 nibbles follow.
        var fbs = std.io.fixedBufferStream(src[1..]);
        symbol_count = try decodeDirectHuffmanTree(fbs.reader(), header - 127, &weights);
        bytes_read += fbs.pos;
    }
    consumed_count.* += bytes_read;
    return buildHuffmanTree(&weights, symbol_count);
}
/// Comparator ordering prefixed symbols by ascending weight.
fn lessThanByWeight(
    weights: [256]u4,
    lhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
    rhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
) bool {
    // NOTE: this function relies on the use of a stable sorting algorithm,
    // otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs;
    // should be added
    const lhs_weight = weights[lhs.symbol];
    const rhs_weight = weights[rhs.symbol];
    return lhs_weight < rhs_weight;
}

File diff suppressed because it is too large. Load Diff

View File

@ -0,0 +1,75 @@
const std = @import("std");
/// Reads a byte slice backwards, one byte per read call, exposing a
/// standard `std.io.Reader` interface.
pub const ReversedByteReader = struct {
    remaining_bytes: usize,
    bytes: []const u8,

    const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn);

    pub fn init(bytes: []const u8) ReversedByteReader {
        return .{
            .bytes = bytes,
            .remaining_bytes = bytes.len,
        };
    }

    pub fn reader(self: *ReversedByteReader) Reader {
        return .{ .context = self };
    }

    /// Emits exactly one byte per call, starting from the end of `bytes`;
    /// returns 0 once the slice is exhausted.
    fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize {
        if (ctx.remaining_bytes == 0) return 0;
        ctx.remaining_bytes -= 1;
        buffer[0] = ctx.bytes[ctx.remaining_bytes];
        return 1;
    }
};
/// A bit reader for reading the reversed bit streams used to encode
/// FSE compressed data. The stream is read from its last byte toward its
/// first, and begins at the highest set bit of the last byte.
pub const ReverseBitReader = struct {
    byte_reader: ReversedByteReader,
    bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),

    /// Position the reader just past the mandatory start bit (the first 1
    /// bit from the end of `bytes`). Fails if the stream contains no 1 bit.
    pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
        self.byte_reader = ReversedByteReader.init(bytes);
        self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
        while (true) {
            const bit = self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit;
            if (bit == 1) break;
        }
    }

    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
        return self.bit_reader.readBitsNoEof(U, num_bits);
    }

    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
        return try self.bit_reader.readBits(U, num_bits, out_bits);
    }

    pub fn alignToByte(self: *@This()) void {
        self.bit_reader.alignToByte();
    }
};
/// A forward (little-endian) bit reader wrapping any `Reader`, used for
/// non-reversed bit streams.
pub fn BitReader(comptime Reader: type) type {
    return struct {
        underlying: std.io.BitReader(.Little, Reader),

        const Self = @This();

        pub fn readBitsNoEof(self: *Self, comptime U: type, num_bits: usize) !U {
            return self.underlying.readBitsNoEof(U, num_bits);
        }

        pub fn readBits(self: *Self, comptime U: type, num_bits: usize, out_bits: *usize) !U {
            return self.underlying.readBits(U, num_bits, out_bits);
        }

        pub fn alignToByte(self: *Self) void {
            self.underlying.alignToByte();
        }
    };
}
/// Construct a little-endian `BitReader` over `reader`.
pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
    const wrapped = std.io.bitReader(.Little, reader);
    return .{ .underlying = wrapped };
}