std.compress.zstandard: clean up api

commit fc64c279a4
parent cbfaa876d4
Author: dweiller
Date:   2023-01-23 16:26:03 +11:00
3 changed files with 113 additions and 101 deletions

lib/std/compress/zstandard.zig

@@ -1,6 +1,7 @@
const std = @import("std");
pub const decompress = @import("zstandard/decompress.zig");
pub usingnamespace @import("zstandard/types.zig");
test "decompression" {
const uncompressed = @embedFile("testdata/rfc8478.txt");
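
For orientation, here is a minimal sketch (not part of the diff) of what the usingnamespace line above makes reachable, assuming this file is wired up as std.compress.zstandard: decompress.zig stays a named sub-namespace, while the public declarations of zstandard/types.zig (frame, compressed_block, and the section types renamed in this commit) are merged into the zstandard namespace itself.

const std = @import("std");
const zstd = std.compress.zstandard;

comptime {
    // Re-exported sub-namespace from the `pub const decompress` line above.
    _ = zstd.decompress.decodeFrame;
    // Declarations pulled in by `usingnamespace @import("zstandard/types.zig")`.
    _ = zstd.frame;
    _ = zstd.compressed_block.LiteralsSection;
    _ = zstd.compressed_block.SequencesSection;
}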

lib/std/compress/zstandard/decompress.zig

@@ -3,10 +3,10 @@ const assert = std.debug.assert;
const types = @import("types.zig");
const frame = types.frame;
const Literals = types.compressed_block.Literals;
const Sequences = types.compressed_block.Sequences;
const LiteralsSection = types.compressed_block.LiteralsSection;
const SequencesSection = types.compressed_block.SequencesSection;
const Table = types.compressed_block.Table;
const RingBuffer = @import("RingBuffer.zig");
pub const RingBuffer = @import("RingBuffer.zig");
const readInt = std.mem.readIntLittle;
const readIntSlice = std.mem.readIntSliceLittle;
@@ -55,7 +55,7 @@ pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWrit
};
}
const DecodeState = struct {
pub const DecodeState = struct {
repeat_offsets: [3]u32,
offset: StateData(8),
@@ -70,7 +70,7 @@ const DecodeState = struct {
literal_stream_reader: ReverseBitReader,
literal_stream_index: usize,
huffman_tree: ?Literals.HuffmanTree,
huffman_tree: ?LiteralsSection.HuffmanTree,
literal_written_count: usize,
@@ -84,7 +84,55 @@ const DecodeState = struct {
};
}
fn readInitialState(self: *DecodeState, bit_reader: anytype) !void {
pub fn prepare(
self: *DecodeState,
src: []const u8,
literals: LiteralsSection,
sequences_header: SequencesSection.Header,
) !usize {
if (literals.huffman_tree) |tree| {
self.huffman_tree = tree;
} else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
return error.TreelessLiteralsFirst;
}
switch (literals.header.block_type) {
.raw, .rle => {},
.compressed, .treeless => {
self.literal_stream_index = 0;
switch (literals.streams) {
.one => |slice| try self.initLiteralStream(slice),
.four => |streams| try self.initLiteralStream(streams[0]),
}
},
}
if (sequences_header.sequence_count > 0) {
var bytes_read = try self.updateFseTable(
src,
.literal,
sequences_header.literal_lengths,
);
bytes_read += try self.updateFseTable(
src[bytes_read..],
.offset,
sequences_header.offsets,
);
bytes_read += try self.updateFseTable(
src[bytes_read..],
.match,
sequences_header.match_lengths,
);
self.fse_tables_undefined = false;
return bytes_read;
}
return 0;
}
pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void {
self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
@@ -130,7 +178,7 @@ const DecodeState = struct {
self: *DecodeState,
src: []const u8,
comptime choice: DataType,
mode: Sequences.Header.Mode,
mode: SequencesSection.Header.Mode,
) !usize {
const field_name = @tagName(choice);
switch (mode) {
@@ -213,7 +261,13 @@ const DecodeState = struct {
};
}
fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
fn executeSequenceSlice(
self: *DecodeState,
dest: []u8,
write_pos: usize,
literals: LiteralsSection,
sequence: Sequence,
) !void {
try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
// TODO: should we validate offset against max_window_size?
@@ -225,7 +279,12 @@ const DecodeState = struct {
std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
}
fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
fn executeSequenceRingBuffer(
self: *DecodeState,
dest: *RingBuffer,
literals: LiteralsSection,
sequence: Sequence,
) !void {
try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
// TODO: check that ring buffer window is full enough for match copies
const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
@@ -234,11 +293,11 @@ const DecodeState = struct {
for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
}
fn decodeSequenceSlice(
pub fn decodeSequenceSlice(
self: *DecodeState,
dest: []u8,
write_pos: usize,
literals: Literals,
literals: LiteralsSection,
bit_reader: anytype,
last_sequence: bool,
) !usize {
@@ -255,10 +314,10 @@ const DecodeState = struct {
return sequence.match_length + sequence.literal_length;
}
fn decodeSequenceRingBuffer(
pub fn decodeSequenceRingBuffer(
self: *DecodeState,
dest: *RingBuffer,
literals: Literals,
literals: LiteralsSection,
bit_reader: anytype,
last_sequence: bool,
) !usize {
@@ -280,7 +339,7 @@ const DecodeState = struct {
return sequence.match_length + sequence.literal_length;
}
fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void {
self.literal_stream_index += 1;
try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
}
@@ -290,7 +349,7 @@ const DecodeState = struct {
try self.literal_stream_reader.init(bytes);
}
fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void {
if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
switch (literals.header.block_type) {
.raw => {
@@ -310,7 +369,7 @@ const DecodeState = struct {
// const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
const huffman_tree = self.huffman_tree orelse unreachable;
const max_bit_count = huffman_tree.max_bit_count;
const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
max_bit_count,
);
@@ -345,7 +404,7 @@ const DecodeState = struct {
},
.index => |index| {
huffman_tree_index = index;
const bit_count = Literals.HuffmanTree.weightToBitCount(
const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
huffman_tree.nodes[index].weight,
max_bit_count,
);
@@ -359,7 +418,7 @@ const DecodeState = struct {
}
}
fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void {
if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
switch (literals.header.block_type) {
.raw => {
@@ -378,7 +437,7 @@ const DecodeState = struct {
// const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
const huffman_tree = self.huffman_tree orelse unreachable;
const max_bit_count = huffman_tree.max_bit_count;
const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
max_bit_count,
);
@@ -413,7 +472,7 @@ const DecodeState = struct {
},
.index => |index| {
huffman_tree_index = index;
const bit_count = Literals.HuffmanTree.weightToBitCount(
const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
huffman_tree.nodes[index].weight,
max_bit_count,
);
@@ -647,54 +706,6 @@ fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21,
return block_size;
}
fn prepareDecodeState(
decode_state: *DecodeState,
src: []const u8,
literals: Literals,
sequences_header: Sequences.Header,
) !usize {
if (literals.huffman_tree) |tree| {
decode_state.huffman_tree = tree;
} else if (literals.header.block_type == .treeless and decode_state.huffman_tree == null) {
return error.TreelessLiteralsFirst;
}
switch (literals.header.block_type) {
.raw, .rle => {},
.compressed, .treeless => {
decode_state.literal_stream_index = 0;
switch (literals.streams) {
.one => |slice| try decode_state.initLiteralStream(slice),
.four => |streams| try decode_state.initLiteralStream(streams[0]),
}
},
}
if (sequences_header.sequence_count > 0) {
var bytes_read = try decode_state.updateFseTable(
src,
.literal,
sequences_header.literal_lengths,
);
bytes_read += try decode_state.updateFseTable(
src[bytes_read..],
.offset,
sequences_header.offsets,
);
bytes_read += try decode_state.updateFseTable(
src[bytes_read..],
.match,
sequences_header.match_lengths,
);
decode_state.fse_tables_undefined = false;
return bytes_read;
}
return 0;
}
pub fn decodeBlock(
dest: []u8,
src: []const u8,
@@ -715,7 +726,7 @@ pub fn decodeBlock(
const literals = try decodeLiteralsSection(src, &bytes_read);
const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
var bytes_written: usize = 0;
if (sequences_header.sequence_count > 0) {
@@ -723,7 +734,7 @@ pub fn decodeBlock(
var bit_stream: ReverseBitReader = undefined;
try bit_stream.init(bit_stream_bytes);
try decode_state.readInitialState(&bit_stream);
try decode_state.readInitialFseState(&bit_stream);
var i: usize = 0;
while (i < sequences_header.sequence_count) : (i += 1) {
@@ -780,7 +791,7 @@ pub fn decodeBlockRingBuffer(
const literals = try decodeLiteralsSection(src, &bytes_read);
const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
var bytes_written: usize = 0;
if (sequences_header.sequence_count > 0) {
@@ -788,7 +799,7 @@ pub fn decodeBlockRingBuffer(
var bit_stream: ReverseBitReader = undefined;
try bit_stream.init(bit_stream_bytes);
try decode_state.readInitialState(&bit_stream);
try decode_state.readInitialFseState(&bit_stream);
var i: usize = 0;
while (i < sequences_header.sequence_count) : (i += 1) {
@@ -928,7 +939,7 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
};
}
pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals {
pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
// TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
var bytes_read: usize = 0;
const header = decodeLiteralsHeader(src, &bytes_read);
@@ -936,7 +947,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
.raw => {
const stream = src[bytes_read .. bytes_read + header.regenerated_size];
consumed_count.* += header.regenerated_size + bytes_read;
return Literals{
return LiteralsSection{
.header = header,
.huffman_tree = null,
.streams = .{ .one = stream },
@@ -945,7 +956,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
.rle => {
const stream = src[bytes_read .. bytes_read + 1];
consumed_count.* += 1 + bytes_read;
return Literals{
return LiteralsSection{
.header = header,
.huffman_tree = null,
.streams = .{ .one = stream },
@@ -966,7 +977,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
const stream = src[bytes_read .. bytes_read + total_streams_size];
bytes_read += total_streams_size;
consumed_count.* += bytes_read;
return Literals{
return LiteralsSection{
.header = header,
.huffman_tree = huffman_tree,
.streams = .{ .one = stream },
@@ -988,7 +999,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
consumed_count.* += total_streams_size + bytes_read;
return Literals{
return LiteralsSection{
.header = header,
.huffman_tree = huffman_tree,
.streams = .{ .four = .{
@@ -1002,7 +1013,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
}
}
fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanTree {
fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree {
var bytes_read: usize = 0;
bytes_read += 1;
const header = src[0];
@@ -1094,7 +1105,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
weights[symbol_count - 1] = @intCast(u4, std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1);
log.debug("weights[{d}] = {d}", .{ symbol_count - 1, weights[symbol_count - 1] });
var weight_sorted_prefixed_symbols: [256]Literals.HuffmanTree.PrefixedSymbol = undefined;
var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
for (weight_sorted_prefixed_symbols[0..symbol_count]) |_, i| {
weight_sorted_prefixed_symbols[i] = .{
.symbol = @intCast(u8, i),
@@ -1104,7 +1115,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
}
std.sort.sort(
Literals.HuffmanTree.PrefixedSymbol,
LiteralsSection.HuffmanTree.PrefixedSymbol,
weight_sorted_prefixed_symbols[0..symbol_count],
weights,
lessThanByWeight,
@@ -1137,7 +1148,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
}
}
consumed_count.* += bytes_read;
const tree = Literals.HuffmanTree{
const tree = LiteralsSection.HuffmanTree{
.max_bit_count = max_number_of_bits,
.symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1),
.nodes = weight_sorted_prefixed_symbols,
@@ -1148,8 +1159,8 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
fn lessThanByWeight(
weights: [256]u4,
lhs: Literals.HuffmanTree.PrefixedSymbol,
rhs: Literals.HuffmanTree.PrefixedSymbol,
lhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
rhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
) bool {
// NOTE: this function relies on the use of a stable sorting algorithm,
// otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs;
@@ -1157,11 +1168,11 @@ fn lessThanByWeight(
return weights[lhs.symbol] < weights[rhs.symbol];
}
pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.Header {
pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSection.Header {
// TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
const start = consumed_count.*;
const byte0 = src[0];
const block_type = @intToEnum(Literals.BlockType, byte0 & 0b11);
const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
var regenerated_size: u20 = undefined;
var compressed_size: ?u18 = null;
@@ -1220,7 +1231,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He
compressed_size,
},
);
return Literals.Header{
return LiteralsSection.Header{
.block_type = block_type,
.size_format = size_format,
.regenerated_size = regenerated_size,
@@ -1228,7 +1239,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He
};
}
fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Header {
pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
var sequence_count: u24 = undefined;
var bytes_read: usize = 0;
@@ -1237,7 +1248,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
bytes_read += 1;
log.debug("decoded sequences header '{}': sequence count = 0", .{std.fmt.fmtSliceHexUpper(src[0..bytes_read])});
consumed_count.* += bytes_read;
return Sequences.Header{
return SequencesSection.Header{
.sequence_count = 0,
.offsets = undefined,
.match_lengths = undefined,
@@ -1258,9 +1269,9 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
bytes_read += 1;
consumed_count.* += bytes_read;
const matches_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00001100) >> 2);
const offsets_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00110000) >> 4);
const literal_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b11000000) >> 6);
const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2);
const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4);
const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6);
log.debug("decoded sequences header '{}': (sc={d},o={s},m={s},l={s})", .{
std.fmt.fmtSliceHexUpper(src[0..bytes_read]),
sequence_count,
@@ -1270,7 +1281,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
});
if (compression_modes & 0b11 != 0) return error.ReservedBitSet;
return Sequences.Header{
return SequencesSection.Header{
.sequence_count = sequence_count,
.offsets = offsets_mode,
.match_lengths = matches_mode,
@@ -1428,25 +1439,25 @@ const ReversedByteReader = struct {
}
};
const ReverseBitReader = struct {
pub const ReverseBitReader = struct {
byte_reader: ReversedByteReader,
bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),
fn init(self: *ReverseBitReader, bytes: []const u8) !void {
pub fn init(self: *ReverseBitReader, bytes: []const u8) !void {
self.byte_reader = ReversedByteReader.init(bytes);
self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
}
fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
return self.bit_reader.readBitsNoEof(U, num_bits);
}
fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
return try self.bit_reader.readBits(U, num_bits, out_bits);
}
fn alignToByte(self: *@This()) void {
pub fn alignToByte(self: *@This()) void {
self.bit_reader.alignToByte();
}
};
@@ -1514,7 +1525,7 @@ fn dumpFseTable(prefix: []const u8, table: []const Table.Fse) void {
}
}
fn dumpHuffmanTree(tree: Literals.HuffmanTree) void {
fn dumpHuffmanTree(tree: LiteralsSection.HuffmanTree) void {
log.debug("Huffman tree: max bit count = {}, symbol count = {}", .{ tree.max_bit_count, tree.symbol_count_minus_one + 1 });
for (tree.nodes[0 .. tree.symbol_count_minus_one + 1]) |node| {
log.debug("symbol = {[symbol]d}, prefix = {[prefix]d}, weight = {[weight]d}", node);
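
The declarations made public in this file (DecodeState with its prepare, readInitialFseState, decodeSequenceSlice and decodeLiteralsSlice methods, plus ReverseBitReader, decodeLiteralsSection and decodeSequencesHeader) are enough for a caller to drive block decoding directly. The sketch below mirrors the body of decodeBlock shown above and is an illustration only: the DecodeState is assumed to be initialized elsewhere, and the slice handed to the bit reader is simplified (the real decodeBlock bounds it by the compressed block size).

const std = @import("std");
const decompress = std.compress.zstandard.decompress;

fn decodeBlockSketch(
    decode_state: *decompress.DecodeState,
    dest: []u8,
    src: []const u8,
) !usize {
    var bytes_read: usize = 0;
    const literals = try decompress.decodeLiteralsSection(src, &bytes_read);
    const sequences_header = try decompress.decodeSequencesHeader(src[bytes_read..], &bytes_read);
    bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);

    var bytes_written: usize = 0;
    if (sequences_header.sequence_count > 0) {
        var bit_stream: decompress.ReverseBitReader = undefined;
        // Simplification: decodeBlock limits this slice to the compressed block size.
        try bit_stream.init(src[bytes_read..]);
        try decode_state.readInitialFseState(&bit_stream);

        var i: usize = 0;
        while (i < sequences_header.sequence_count) : (i += 1) {
            bytes_written += try decode_state.decodeSequenceSlice(
                dest,
                bytes_written,
                literals,
                &bit_stream,
                i == sequences_header.sequence_count - 1,
            );
        }
    }
    // Literals left over after the last sequence would be copied out here with
    // decodeLiteralsSlice, as decodeBlock does; omitted in this sketch.
    return bytes_written;
}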

lib/std/compress/zstandard/types.zig

@@ -52,7 +52,7 @@ pub const frame = struct {
};
pub const compressed_block = struct {
pub const Literals = struct {
pub const LiteralsSection = struct {
header: Header,
huffman_tree: ?HuffmanTree,
streams: Streams,
@@ -119,8 +119,8 @@ pub const compressed_block = struct {
}
};
pub const Sequences = struct {
header: Sequences.Header,
pub const SequencesSection = struct {
header: SequencesSection.Header,
literals_length_table: Table,
offset_table: Table,
match_length_table: Table,
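
As a small usage sketch of the renamed type (not part of the diff), the compression-modes byte of a sequences section splits into the three SequencesSection.Header.Mode values the same way decodeSequencesHeader does above. The helper name below is hypothetical, and reaching the type through std.compress.zstandard assumes the usingnamespace re-export added in zstandard.zig.

const std = @import("std");
const SequencesSection = std.compress.zstandard.compressed_block.SequencesSection;

fn splitCompressionModes(compression_modes: u8) [3]SequencesSection.Header.Mode {
    // Literal lengths, offsets, match lengths: bits 7-6, 5-4 and 3-2 of the
    // modes byte, matching the shifts used in decodeSequencesHeader.
    return .{
        @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6),
        @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4),
        @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2),
    };
}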