diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 6fb6a70027..d17e6bac46 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -13,7 +13,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool allocator: Allocator, in_reader: ReaderType, - decode_state: decompress.DecodeState, + decode_state: decompress.block.DecodeState, frame_context: decompress.FrameContext, buffer: RingBuffer, last_block: bool, @@ -24,7 +24,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool sequence_buffer: []u8, checksum: if (verify_checksum) ?u32 else void, - pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame, EndOfStream }; + pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame }; pub const Reader = std.io.Reader(*Self, Error, read); @@ -34,21 +34,41 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool .zstandard => { const frame_context = context: { const frame_header = try decompress.decodeZStandardHeader(source); - break :context try decompress.FrameContext.init(frame_header, window_size_max, verify_checksum); + break :context try decompress.FrameContext.init( + frame_header, + window_size_max, + verify_checksum, + ); }; - const literal_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.literal); + const literal_fse_buffer = try allocator.alloc( + types.compressed_block.Table.Fse, + types.compressed_block.table_size_max.literal, + ); errdefer allocator.free(literal_fse_buffer); - const match_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.match); + + const match_fse_buffer = try allocator.alloc( + types.compressed_block.Table.Fse, + types.compressed_block.table_size_max.match, + ); errdefer allocator.free(match_fse_buffer); - const offset_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.offset); + + const offset_fse_buffer = try allocator.alloc( + types.compressed_block.Table.Fse, + types.compressed_block.table_size_max.offset, + ); errdefer allocator.free(offset_fse_buffer); - const decode_state = decompress.DecodeState.init(literal_fse_buffer, match_fse_buffer, offset_fse_buffer); + const decode_state = decompress.block.DecodeState.init( + literal_fse_buffer, + match_fse_buffer, + offset_fse_buffer, + ); const buffer = try RingBuffer.init(allocator, frame_context.window_size); const literals_data = try allocator.alloc(u8, window_size_max); errdefer allocator.free(literals_data); + const sequence_data = try allocator.alloc(u8, window_size_max); errdefer allocator.free(sequence_data); @@ -87,10 +107,10 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool if (buffer.len == 0) return 0; if (self.buffer.isEmpty() and !self.last_block) { - const header_bytes = try self.in_reader.readBytesNoEof(3); - const block_header = decompress.decodeBlockHeader(&header_bytes); + const header_bytes = self.in_reader.readBytesNoEof(3) catch return error.MalformedFrame; + const block_header = decompress.block.decodeBlockHeader(&header_bytes); - decompress.decodeBlockReader( + decompress.block.decodeBlockReader( &self.buffer, self.in_reader, block_header, diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig new file mode 100644 index 0000000000..7a7bf28b31 --- /dev/null +++ b/lib/std/compress/zstandard/decode/block.zig @@ -0,0 +1,1051 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const types = @import("../types.zig"); +const frame = types.frame; +const Table = types.compressed_block.Table; +const LiteralsSection = types.compressed_block.LiteralsSection; +const SequencesSection = types.compressed_block.SequencesSection; + +const huffman = @import("huffman.zig"); + +const RingBuffer = @import("../RingBuffer.zig"); + +const readers = @import("../readers.zig"); + +const decodeFseTable = @import("fse.zig").decodeFseTable; + +const readInt = std.mem.readIntLittle; + +pub const Error = error{ + BlockSizeOverMaximum, + MalformedBlockSize, + ReservedBlock, + MalformedRleBlock, + MalformedCompressedBlock, + EndOfStream, +}; + +pub const DecodeState = struct { + repeat_offsets: [3]u32, + + offset: StateData(8), + match: StateData(9), + literal: StateData(9), + + offset_fse_buffer: []Table.Fse, + match_fse_buffer: []Table.Fse, + literal_fse_buffer: []Table.Fse, + + fse_tables_undefined: bool, + + literal_stream_reader: readers.ReverseBitReader, + literal_stream_index: usize, + literal_streams: LiteralsSection.Streams, + literal_header: LiteralsSection.Header, + huffman_tree: ?LiteralsSection.HuffmanTree, + + literal_written_count: usize, + + fn StateData(comptime max_accuracy_log: comptime_int) type { + return struct { + state: State, + table: Table, + accuracy_log: u8, + + const State = std.meta.Int(.unsigned, max_accuracy_log); + }; + } + + pub fn init( + literal_fse_buffer: []Table.Fse, + match_fse_buffer: []Table.Fse, + offset_fse_buffer: []Table.Fse, + ) DecodeState { + return DecodeState{ + .repeat_offsets = .{ + types.compressed_block.start_repeated_offset_1, + types.compressed_block.start_repeated_offset_2, + types.compressed_block.start_repeated_offset_3, + }, + + .offset = undefined, + .match = undefined, + .literal = undefined, + + .literal_fse_buffer = literal_fse_buffer, + .match_fse_buffer = match_fse_buffer, + .offset_fse_buffer = offset_fse_buffer, + + .fse_tables_undefined = true, + + .literal_written_count = 0, + .literal_header = undefined, + .literal_streams = undefined, + .literal_stream_reader = undefined, + .literal_stream_index = undefined, + .huffman_tree = null, + }; + } + + /// Prepare the decoder to decode a compressed block. Loads the literals + /// stream and Huffman tree from `literals` and reads the FSE tables from + /// `source`. + /// + /// Errors: + /// - returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's + /// first byte does not have any bits set. + /// - returns `error.TreelessLiteralsFirst` `literals` is a treeless literals section + /// and the decode state does not have a Huffman tree from a previous block. + pub fn prepare( + self: *DecodeState, + source: anytype, + literals: LiteralsSection, + sequences_header: SequencesSection.Header, + ) !void { + self.literal_written_count = 0; + self.literal_header = literals.header; + self.literal_streams = literals.streams; + + if (literals.huffman_tree) |tree| { + self.huffman_tree = tree; + } else if (literals.header.block_type == .treeless and self.huffman_tree == null) { + return error.TreelessLiteralsFirst; + } + + switch (literals.header.block_type) { + .raw, .rle => {}, + .compressed, .treeless => { + self.literal_stream_index = 0; + switch (literals.streams) { + .one => |slice| try self.initLiteralStream(slice), + .four => |streams| try self.initLiteralStream(streams[0]), + } + }, + } + + if (sequences_header.sequence_count > 0) { + try self.updateFseTable(source, .literal, sequences_header.literal_lengths); + try self.updateFseTable(source, .offset, sequences_header.offsets); + try self.updateFseTable(source, .match, sequences_header.match_lengths); + self.fse_tables_undefined = false; + } + } + + /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream` + /// if `bit_reader` does not contain enough bits. + pub fn readInitialFseState(self: *DecodeState, bit_reader: *readers.ReverseBitReader) error{EndOfStream}!void { + self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log); + self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log); + self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log); + } + + fn updateRepeatOffset(self: *DecodeState, offset: u32) void { + std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]); + std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]); + self.repeat_offsets[0] = offset; + } + + fn useRepeatOffset(self: *DecodeState, index: usize) u32 { + if (index == 1) + std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]) + else if (index == 2) { + std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]); + std.mem.swap(u32, &self.repeat_offsets[1], &self.repeat_offsets[2]); + } + return self.repeat_offsets[0]; + } + + const DataType = enum { offset, match, literal }; + + fn updateState( + self: *DecodeState, + comptime choice: DataType, + bit_reader: *readers.ReverseBitReader, + ) error{ MalformedFseBits, EndOfStream }!void { + switch (@field(self, @tagName(choice)).table) { + .rle => {}, + .fse => |table| { + const data = table[@field(self, @tagName(choice)).state]; + const T = @TypeOf(@field(self, @tagName(choice))).State; + const bits_summand = try bit_reader.readBitsNoEof(T, data.bits); + const next_state = std.math.cast( + @TypeOf(@field(self, @tagName(choice))).State, + data.baseline + bits_summand, + ) orelse return error.MalformedFseBits; + @field(self, @tagName(choice)).state = next_state; + }, + } + } + + const FseTableError = error{ + MalformedFseTable, + MalformedAccuracyLog, + RepeatModeFirst, + EndOfStream, + }; + + fn updateFseTable( + self: *DecodeState, + source: anytype, + comptime choice: DataType, + mode: SequencesSection.Header.Mode, + ) !void { + const field_name = @tagName(choice); + switch (mode) { + .predefined => { + @field(self, field_name).accuracy_log = + @field(types.compressed_block.default_accuracy_log, field_name); + + @field(self, field_name).table = + @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table"); + }, + .rle => { + @field(self, field_name).accuracy_log = 0; + @field(self, field_name).table = .{ .rle = try source.readByte() }; + }, + .fse => { + var bit_reader = readers.bitReader(source); + + const table_size = try decodeFseTable( + &bit_reader, + @field(types.compressed_block.table_symbol_count_max, field_name), + @field(types.compressed_block.table_accuracy_log_max, field_name), + @field(self, field_name ++ "_fse_buffer"), + ); + @field(self, field_name).table = .{ + .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size], + }; + @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size); + }, + .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst, + } + } + + const Sequence = struct { + literal_length: u32, + match_length: u32, + offset: u32, + }; + + fn nextSequence( + self: *DecodeState, + bit_reader: *readers.ReverseBitReader, + ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence { + const raw_code = self.getCode(.offset); + const offset_code = std.math.cast(u5, raw_code) orelse { + return error.OffsetCodeTooLarge; + }; + const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code); + + const match_code = self.getCode(.match); + const match = types.compressed_block.match_length_code_table[match_code]; + const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]); + + const literal_code = self.getCode(.literal); + const literal = types.compressed_block.literals_length_code_table[literal_code]; + const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]); + + const offset = if (offset_value > 3) offset: { + const offset = offset_value - 3; + self.updateRepeatOffset(offset); + break :offset offset; + } else offset: { + if (literal_length == 0) { + if (offset_value == 3) { + const offset = self.repeat_offsets[0] - 1; + self.updateRepeatOffset(offset); + break :offset offset; + } + break :offset self.useRepeatOffset(offset_value); + } + break :offset self.useRepeatOffset(offset_value - 1); + }; + + return .{ + .literal_length = literal_length, + .match_length = match_length, + .offset = offset, + }; + } + + fn executeSequenceSlice( + self: *DecodeState, + dest: []u8, + write_pos: usize, + sequence: Sequence, + ) (error{MalformedSequence} || DecodeLiteralsError)!void { + if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence; + + try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length); + const copy_start = write_pos + sequence.literal_length - sequence.offset; + const copy_end = copy_start + sequence.match_length; + // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr + // to allow repeats + std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]); + } + + fn executeSequenceRingBuffer( + self: *DecodeState, + dest: *RingBuffer, + sequence: Sequence, + ) (error{MalformedSequence} || DecodeLiteralsError)!void { + if (sequence.offset > dest.data.len) return error.MalformedSequence; + + try self.decodeLiteralsRingBuffer(dest, sequence.literal_length); + const copy_start = dest.write_index + dest.data.len - sequence.offset; + const copy_slice = dest.sliceAt(copy_start, sequence.match_length); + // TODO: would std.mem.copy and figuring out dest slice be better/faster? + for (copy_slice.first) |b| dest.writeAssumeCapacity(b); + for (copy_slice.second) |b| dest.writeAssumeCapacity(b); + } + + const DecodeSequenceError = error{ + OffsetCodeTooLarge, + EndOfStream, + MalformedSequence, + MalformedFseBits, + } || DecodeLiteralsError; + + /// Decode one sequence from `bit_reader` into `dest`, written starting at + /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns + /// `error.MalformedSequence` error if the decompressed sequence would be longer + /// than `sequence_size_limit` or the sequence's offset is too large; returns + /// `error.EndOfStream` if `bit_reader` does not contain enough bits; returns + /// `error.UnexpectedEndOfLiteralStream` if the decoder state's literal streams + /// do not contain enough literals for the sequence (this may mean the literal + /// stream or the sequence is malformed). + pub fn decodeSequenceSlice( + self: *DecodeState, + dest: []u8, + write_pos: usize, + bit_reader: *readers.ReverseBitReader, + sequence_size_limit: usize, + last_sequence: bool, + ) DecodeSequenceError!usize { + const sequence = try self.nextSequence(bit_reader); + const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; + if (sequence_length > sequence_size_limit) return error.MalformedSequence; + + try self.executeSequenceSlice(dest, write_pos, sequence); + if (!last_sequence) { + try self.updateState(.literal, bit_reader); + try self.updateState(.match, bit_reader); + try self.updateState(.offset, bit_reader); + } + return sequence_length; + } + + /// Decode one sequence from `bit_reader` into `dest`; see `decodeSequenceSlice`. + pub fn decodeSequenceRingBuffer( + self: *DecodeState, + dest: *RingBuffer, + bit_reader: anytype, + sequence_size_limit: usize, + last_sequence: bool, + ) DecodeSequenceError!usize { + const sequence = try self.nextSequence(bit_reader); + const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; + if (sequence_length > sequence_size_limit) return error.MalformedSequence; + + try self.executeSequenceRingBuffer(dest, sequence); + if (!last_sequence) { + try self.updateState(.literal, bit_reader); + try self.updateState(.match, bit_reader); + try self.updateState(.offset, bit_reader); + } + return sequence_length; + } + + fn nextLiteralMultiStream( + self: *DecodeState, + ) error{BitStreamHasNoStartBit}!void { + self.literal_stream_index += 1; + try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]); + } + + pub fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void { + try self.literal_stream_reader.init(bytes); + } + + const LiteralBitsError = error{ + BitStreamHasNoStartBit, + UnexpectedEndOfLiteralStream, + }; + fn readLiteralsBits( + self: *DecodeState, + comptime T: type, + bit_count_to_read: usize, + ) LiteralBitsError!T { + return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: { + if (self.literal_streams == .four and self.literal_stream_index < 3) { + try self.nextLiteralMultiStream(); + break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch + return error.UnexpectedEndOfLiteralStream; + } else { + return error.UnexpectedEndOfLiteralStream; + } + }; + } + + const DecodeLiteralsError = error{ + MalformedLiteralsLength, + PrefixNotFound, + } || LiteralBitsError; + + /// Decode `len` bytes of literals into `dest`. `literals` should be the + /// `LiteralsSection` that was passed to `prepare()`. Returns + /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by + /// `self` plus `len` is greater than the regenerated size of `literals`. + /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if + /// there are problems decoding Huffman compressed literals. + pub fn decodeLiteralsSlice( + self: *DecodeState, + dest: []u8, + len: usize, + ) DecodeLiteralsError!void { + if (self.literal_written_count + len > self.literal_header.regenerated_size) + return error.MalformedLiteralsLength; + + switch (self.literal_header.block_type) { + .raw => { + const literals_end = self.literal_written_count + len; + const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; + std.mem.copy(u8, dest, literal_data); + self.literal_written_count += len; + }, + .rle => { + var i: usize = 0; + while (i < len) : (i += 1) { + dest[i] = self.literal_streams.one[0]; + } + self.literal_written_count += len; + }, + .compressed, .treeless => { + // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; + const huffman_tree = self.huffman_tree orelse unreachable; + const max_bit_count = huffman_tree.max_bit_count; + const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( + huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, + max_bit_count, + ); + var bits_read: u4 = 0; + var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one; + var bit_count_to_read: u4 = starting_bit_count; + var i: usize = 0; + while (i < len) : (i += 1) { + var prefix: u16 = 0; + while (true) { + const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| { + return err; + }; + prefix <<= bit_count_to_read; + prefix |= new_bits; + bits_read += bit_count_to_read; + const result = huffman_tree.query(huffman_tree_index, prefix) catch |err| { + return err; + }; + + switch (result) { + .symbol => |sym| { + dest[i] = sym; + bit_count_to_read = starting_bit_count; + bits_read = 0; + huffman_tree_index = huffman_tree.symbol_count_minus_one; + break; + }, + .index => |index| { + huffman_tree_index = index; + const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( + huffman_tree.nodes[index].weight, + max_bit_count, + ); + bit_count_to_read = bit_count - bits_read; + }, + } + } + } + self.literal_written_count += len; + }, + } + } + + /// Decode literals into `dest`; see `decodeLiteralsSlice()`. + pub fn decodeLiteralsRingBuffer( + self: *DecodeState, + dest: *RingBuffer, + len: usize, + ) DecodeLiteralsError!void { + if (self.literal_written_count + len > self.literal_header.regenerated_size) + return error.MalformedLiteralsLength; + + switch (self.literal_header.block_type) { + .raw => { + const literals_end = self.literal_written_count + len; + const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; + dest.writeSliceAssumeCapacity(literal_data); + self.literal_written_count += len; + }, + .rle => { + var i: usize = 0; + while (i < len) : (i += 1) { + dest.writeAssumeCapacity(self.literal_streams.one[0]); + } + self.literal_written_count += len; + }, + .compressed, .treeless => { + // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; + const huffman_tree = self.huffman_tree orelse unreachable; + const max_bit_count = huffman_tree.max_bit_count; + const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( + huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, + max_bit_count, + ); + var bits_read: u4 = 0; + var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one; + var bit_count_to_read: u4 = starting_bit_count; + var i: usize = 0; + while (i < len) : (i += 1) { + var prefix: u16 = 0; + while (true) { + const new_bits = try self.readLiteralsBits(u16, bit_count_to_read); + prefix <<= bit_count_to_read; + prefix |= new_bits; + bits_read += bit_count_to_read; + const result = try huffman_tree.query(huffman_tree_index, prefix); + + switch (result) { + .symbol => |sym| { + dest.writeAssumeCapacity(sym); + bit_count_to_read = starting_bit_count; + bits_read = 0; + huffman_tree_index = huffman_tree.symbol_count_minus_one; + break; + }, + .index => |index| { + huffman_tree_index = index; + const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( + huffman_tree.nodes[index].weight, + max_bit_count, + ); + bit_count_to_read = bit_count - bits_read; + }, + } + } + } + self.literal_written_count += len; + }, + } + } + + fn getCode(self: *DecodeState, comptime choice: DataType) u32 { + return switch (@field(self, @tagName(choice)).table) { + .rle => |value| value, + .fse => |table| table[@field(self, @tagName(choice)).state].symbol, + }; + } +}; + +/// Decode a single block from `src` into `dest`. The beginning of `src` must be +/// the start of the block content (i.e. directly after the block header). +/// Increments `consumed_count` by the number of bytes read from `src` to decode +/// the block and returns the decompressed size of the block. +/// +/// Errors returned: +/// +/// - `error.BlockSizeOverMaximum` if block's size is larger than 1 << 17 or +/// `dest[written_count..].len` +/// - `error.MalformedBlockSize` if `src.len` is smaller than the block size +/// and the block is a raw or compressed block +/// - `error.ReservedBlock` if the block is a reserved block +/// - `error.MalformedRleBlock` if the block is an RLE block and `src.len < 1` +/// - `error.MalformedCompressedBlock` if there are errors decoding a +/// compressed block +/// - `error.EndOfStream` if the sequence bit stream ends unexpectedly +pub fn decodeBlock( + dest: []u8, + src: []const u8, + block_header: frame.ZStandard.Block.Header, + decode_state: *DecodeState, + consumed_count: *usize, + written_count: usize, +) Error!usize { + const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB + const block_size = block_header.block_size; + if (block_size_max < block_size) return error.BlockSizeOverMaximum; + switch (block_header.block_type) { + .raw => { + if (src.len < block_size) return error.MalformedBlockSize; + const data = src[0..block_size]; + std.mem.copy(u8, dest[written_count..], data); + consumed_count.* += block_size; + return block_size; + }, + .rle => { + if (src.len < 1) return error.MalformedRleBlock; + var write_pos: usize = written_count; + while (write_pos < block_size + written_count) : (write_pos += 1) { + dest[write_pos] = src[0]; + } + consumed_count.* += 1; + return block_size; + }, + .compressed => { + if (src.len < block_size) return error.MalformedBlockSize; + var bytes_read: usize = 0; + const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch + return error.MalformedCompressedBlock; + var fbs = std.io.fixedBufferStream(src[bytes_read..]); + const fbs_reader = fbs.reader(); + const sequences_header = decodeSequencesHeader(fbs_reader) catch + return error.MalformedCompressedBlock; + + decode_state.prepare(fbs_reader, literals, sequences_header) catch + return error.MalformedCompressedBlock; + + bytes_read += fbs.pos; + + var bytes_written: usize = 0; + if (sequences_header.sequence_count > 0) { + const bit_stream_bytes = src[bytes_read..block_size]; + var bit_stream: readers.ReverseBitReader = undefined; + bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock; + + decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; + + var sequence_size_limit = block_size_max; + var i: usize = 0; + while (i < sequences_header.sequence_count) : (i += 1) { + const write_pos = written_count + bytes_written; + const decompressed_size = decode_state.decodeSequenceSlice( + dest, + write_pos, + &bit_stream, + sequence_size_limit, + i == sequences_header.sequence_count - 1, + ) catch return error.MalformedCompressedBlock; + bytes_written += decompressed_size; + sequence_size_limit -= decompressed_size; + } + + bytes_read += bit_stream_bytes.len; + } + if (bytes_read != block_size) return error.MalformedCompressedBlock; + + if (decode_state.literal_written_count < literals.header.regenerated_size) { + const len = literals.header.regenerated_size - decode_state.literal_written_count; + decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch + return error.MalformedCompressedBlock; + bytes_written += len; + } + + consumed_count.* += bytes_read; + return bytes_written; + }, + .reserved => return error.ReservedBlock, + } +} + +/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns +/// the size of the decompressed block, which can be used with `dest.sliceLast()` +/// to get the decompressed bytes. `error.BlockSizeOverMaximum` is returned if +/// the block's compressed or decompressed size is larger than `block_size_max`. +pub fn decodeBlockRingBuffer( + dest: *RingBuffer, + src: []const u8, + block_header: frame.ZStandard.Block.Header, + decode_state: *DecodeState, + consumed_count: *usize, + block_size_max: usize, +) Error!usize { + const block_size = block_header.block_size; + if (block_size_max < block_size) return error.BlockSizeOverMaximum; + switch (block_header.block_type) { + .raw => { + if (src.len < block_size) return error.MalformedBlockSize; + const data = src[0..block_size]; + dest.writeSliceAssumeCapacity(data); + consumed_count.* += block_size; + return block_size; + }, + .rle => { + if (src.len < 1) return error.MalformedRleBlock; + var write_pos: usize = 0; + while (write_pos < block_size) : (write_pos += 1) { + dest.writeAssumeCapacity(src[0]); + } + consumed_count.* += 1; + return block_size; + }, + .compressed => { + if (src.len < block_size) return error.MalformedBlockSize; + var bytes_read: usize = 0; + const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch + return error.MalformedCompressedBlock; + var fbs = std.io.fixedBufferStream(src[bytes_read..]); + const fbs_reader = fbs.reader(); + const sequences_header = decodeSequencesHeader(fbs_reader) catch + return error.MalformedCompressedBlock; + + decode_state.prepare(fbs_reader, literals, sequences_header) catch + return error.MalformedCompressedBlock; + + bytes_read += fbs.pos; + + var bytes_written: usize = 0; + if (sequences_header.sequence_count > 0) { + const bit_stream_bytes = src[bytes_read..block_size]; + var bit_stream: readers.ReverseBitReader = undefined; + bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock; + + decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; + + var sequence_size_limit = block_size_max; + var i: usize = 0; + while (i < sequences_header.sequence_count) : (i += 1) { + const decompressed_size = decode_state.decodeSequenceRingBuffer( + dest, + &bit_stream, + sequence_size_limit, + i == sequences_header.sequence_count - 1, + ) catch return error.MalformedCompressedBlock; + bytes_written += decompressed_size; + sequence_size_limit -= decompressed_size; + } + + bytes_read += bit_stream_bytes.len; + } + if (bytes_read != block_size) return error.MalformedCompressedBlock; + + if (decode_state.literal_written_count < literals.header.regenerated_size) { + const len = literals.header.regenerated_size - decode_state.literal_written_count; + decode_state.decodeLiteralsRingBuffer(dest, len) catch + return error.MalformedCompressedBlock; + bytes_written += len; + } + + consumed_count.* += bytes_read; + if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; + return bytes_written; + }, + .reserved => return error.ReservedBlock, + } +} + +/// Decode a single block from `source` into `dest`. Literal and sequence data +/// from the block is copied into `literals_buffer` and `sequence_buffer`, which +/// must be large enough or `error.LiteralsBufferTooSmall` and +/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an +/// upper bound for the size of both buffers). See `decodeBlock` +/// and `decodeBlockRingBuffer` for function that can decode a block without +/// these extra copies. +pub fn decodeBlockReader( + dest: *RingBuffer, + source: anytype, + block_header: frame.ZStandard.Block.Header, + decode_state: *DecodeState, + block_size_max: usize, + literals_buffer: []u8, + sequence_buffer: []u8, +) !void { + const block_size = block_header.block_size; + var block_reader_limited = std.io.limitedReader(source, block_size); + const block_reader = block_reader_limited.reader(); + if (block_size_max < block_size) return error.BlockSizeOverMaximum; + switch (block_header.block_type) { + .raw => { + const slice = dest.sliceAt(dest.write_index, block_size); + try source.readNoEof(slice.first); + try source.readNoEof(slice.second); + dest.write_index = dest.mask2(dest.write_index + block_size); + }, + .rle => { + const byte = try source.readByte(); + var i: usize = 0; + while (i < block_size) : (i += 1) { + dest.writeAssumeCapacity(byte); + } + }, + .compressed => { + const literals = try decodeLiteralsSection(block_reader, literals_buffer); + const sequences_header = try decodeSequencesHeader(block_reader); + + try decode_state.prepare(block_reader, literals, sequences_header); + + if (sequences_header.sequence_count > 0) { + if (sequence_buffer.len < block_reader_limited.bytes_left) + return error.SequenceBufferTooSmall; + + const size = try block_reader.readAll(sequence_buffer); + var bit_stream: readers.ReverseBitReader = undefined; + try bit_stream.init(sequence_buffer[0..size]); + + decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; + + var sequence_size_limit = block_size_max; + var i: usize = 0; + while (i < sequences_header.sequence_count) : (i += 1) { + const decompressed_size = decode_state.decodeSequenceRingBuffer( + dest, + &bit_stream, + sequence_size_limit, + i == sequences_header.sequence_count - 1, + ) catch return error.MalformedCompressedBlock; + sequence_size_limit -= decompressed_size; + } + } + + if (decode_state.literal_written_count < literals.header.regenerated_size) { + const len = literals.header.regenerated_size - decode_state.literal_written_count; + decode_state.decodeLiteralsRingBuffer(dest, len) catch + return error.MalformedCompressedBlock; + } + + decode_state.literal_written_count = 0; + assert(block_reader.readByte() == error.EndOfStream); + }, + .reserved => return error.ReservedBlock, + } +} + +/// Decode the header of a block. +pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { + const last_block = src[0] & 1 == 1; + const block_type = @intToEnum(frame.ZStandard.Block.Type, (src[0] & 0b110) >> 1); + const block_size = ((src[0] & 0b11111000) >> 3) + (@as(u21, src[1]) << 5) + (@as(u21, src[2]) << 13); + return .{ + .last_block = last_block, + .block_type = block_type, + .block_size = block_size, + }; +} + +pub fn decodeBlockHeaderSlice(src: []const u8) error{EndOfStream}!frame.ZStandard.Block.Header { + if (src.len < 3) return error.EndOfStream; + return decodeBlockHeader(src[0..3]); +} + +/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the +/// number of bytes the section uses. +/// +/// Errors: +/// - returns `error.MalformedLiteralsHeader` if the header is invalid +/// - returns `error.MalformedLiteralsSection` if there are errors decoding +pub fn decodeLiteralsSectionSlice( + src: []const u8, + consumed_count: *usize, +) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || huffman.Error)!LiteralsSection { + var bytes_read: usize = 0; + const header = header: { + var fbs = std.io.fixedBufferStream(src); + defer bytes_read = fbs.pos; + break :header decodeLiteralsHeader(fbs.reader()) catch return error.MalformedLiteralsHeader; + }; + switch (header.block_type) { + .raw => { + if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection; + const stream = src[bytes_read .. bytes_read + header.regenerated_size]; + consumed_count.* += header.regenerated_size + bytes_read; + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = stream }, + }; + }, + .rle => { + if (src.len < bytes_read + 1) return error.MalformedLiteralsSection; + const stream = src[bytes_read .. bytes_read + 1]; + consumed_count.* += 1 + bytes_read; + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = stream }, + }; + }, + .compressed, .treeless => { + const huffman_tree_start = bytes_read; + const huffman_tree = if (header.block_type == .compressed) + try huffman.decodeHuffmanTreeSlice(src[bytes_read..], &bytes_read) + else + null; + const huffman_tree_size = bytes_read - huffman_tree_start; + const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size; + + if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection; + const stream_data = src[bytes_read .. bytes_read + total_streams_size]; + + const streams = try decodeStreams(header.size_format, stream_data); + consumed_count.* += bytes_read + total_streams_size; + return LiteralsSection{ + .header = header, + .huffman_tree = huffman_tree, + .streams = streams, + }; + }, + } +} + +/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the +/// number of bytes the section uses. +/// +/// Errors: +/// - returns `error.MalformedLiteralsHeader` if the header is invalid +/// - returns `error.MalformedLiteralsSection` if there are errors decoding +pub fn decodeLiteralsSection( + source: anytype, + buffer: []u8, +) !LiteralsSection { + const header = try decodeLiteralsHeader(source); + switch (header.block_type) { + .raw => { + try source.readNoEof(buffer[0..header.regenerated_size]); + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = buffer }, + }; + }, + .rle => { + buffer[0] = try source.readByte(); + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = buffer[0..1] }, + }; + }, + .compressed, .treeless => { + var counting_reader = std.io.countingReader(source); + const huffman_tree = if (header.block_type == .compressed) + try huffman.decodeHuffmanTree(counting_reader.reader(), buffer) + else + null; + const huffman_tree_size = counting_reader.bytes_read; + const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size); + + if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall; + try source.readNoEof(buffer[0..total_streams_size]); + const stream_data = buffer[0..total_streams_size]; + + const streams = try decodeStreams(header.size_format, stream_data); + return LiteralsSection{ + .header = header, + .huffman_tree = huffman_tree, + .streams = streams, + }; + }, + } +} + +fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams { + if (size_format == 0) { + return .{ .one = stream_data }; + } + + if (stream_data.len < 6) return error.MalformedLiteralsSection; + + const stream_1_length = @as(usize, readInt(u16, stream_data[0..2])); + const stream_2_length = @as(usize, readInt(u16, stream_data[2..4])); + const stream_3_length = @as(usize, readInt(u16, stream_data[4..6])); + + const stream_1_start = 6; + const stream_2_start = stream_1_start + stream_1_length; + const stream_3_start = stream_2_start + stream_2_length; + const stream_4_start = stream_3_start + stream_3_length; + + return .{ .four = .{ + stream_data[stream_1_start .. stream_1_start + stream_1_length], + stream_data[stream_2_start .. stream_2_start + stream_2_length], + stream_data[stream_3_start .. stream_3_start + stream_3_length], + stream_data[stream_4_start..], + } }; +} + +/// Decode a literals section header. +pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header { + const byte0 = try source.readByte(); + const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11); + const size_format = @intCast(u2, (byte0 & 0b1100) >> 2); + var regenerated_size: u20 = undefined; + var compressed_size: ?u18 = null; + switch (block_type) { + .raw, .rle => { + switch (size_format) { + 0, 2 => { + regenerated_size = byte0 >> 3; + }, + 1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4), + 3 => regenerated_size = (byte0 >> 4) + + (@as(u20, try source.readByte()) << 4) + + (@as(u20, try source.readByte()) << 12), + } + }, + .compressed, .treeless => { + const byte1 = try source.readByte(); + const byte2 = try source.readByte(); + switch (size_format) { + 0, 1 => { + regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4); + compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2); + }, + 2 => { + const byte3 = try source.readByte(); + regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12); + compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6); + }, + 3 => { + const byte3 = try source.readByte(); + const byte4 = try source.readByte(); + regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12); + compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10); + }, + } + }, + } + return LiteralsSection.Header{ + .block_type = block_type, + .size_format = size_format, + .regenerated_size = regenerated_size, + .compressed_size = compressed_size, + }; +} + +/// Decode a sequences section header. +/// +/// Errors: +/// - returns `error.ReservedBitSet` is the reserved bit is set +/// - returns `error.MalformedSequencesHeader` if the header is invalid +pub fn decodeSequencesHeader( + source: anytype, +) !SequencesSection.Header { + var sequence_count: u24 = undefined; + + const byte0 = try source.readByte(); + if (byte0 == 0) { + return SequencesSection.Header{ + .sequence_count = 0, + .offsets = undefined, + .match_lengths = undefined, + .literal_lengths = undefined, + }; + } else if (byte0 < 128) { + sequence_count = byte0; + } else if (byte0 < 255) { + sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte(); + } else { + sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00; + } + + const compression_modes = try source.readByte(); + + const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2); + const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4); + const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6); + if (compression_modes & 0b11 != 0) return error.ReservedBitSet; + + return SequencesSection.Header{ + .sequence_count = sequence_count, + .offsets = offsets_mode, + .match_lengths = matches_mode, + .literal_lengths = literal_mode, + }; +} diff --git a/lib/std/compress/zstandard/decode/fse.zig b/lib/std/compress/zstandard/decode/fse.zig new file mode 100644 index 0000000000..5f87c1f81b --- /dev/null +++ b/lib/std/compress/zstandard/decode/fse.zig @@ -0,0 +1,154 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const types = @import("../types.zig"); +const Table = types.compressed_block.Table; + +pub fn decodeFseTable( + bit_reader: anytype, + expected_symbol_count: usize, + max_accuracy_log: u4, + entries: []Table.Fse, +) !usize { + const accuracy_log_biased = try bit_reader.readBitsNoEof(u4, 4); + if (accuracy_log_biased > max_accuracy_log -| 5) return error.MalformedAccuracyLog; + const accuracy_log = accuracy_log_biased + 5; + + var values: [256]u16 = undefined; + var value_count: usize = 0; + + const total_probability = @as(u16, 1) << accuracy_log; + var accumulated_probability: u16 = 0; + + while (accumulated_probability < total_probability) { + // WARNING: The RFC in poorly worded, and would suggest std.math.log2_int_ceil is correct here, + // but power of two (remaining probabilities + 1) need max bits set to 1 more. + const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1; + const small = try bit_reader.readBitsNoEof(u16, max_bits - 1); + + const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1); + + const value = if (small < cutoff) + small + else value: { + const value_read = small + (try bit_reader.readBitsNoEof(u16, 1) << (max_bits - 1)); + break :value if (value_read < @as(u16, 1) << (max_bits - 1)) + value_read + else + value_read - cutoff; + }; + + accumulated_probability += if (value != 0) value - 1 else 1; + + values[value_count] = value; + value_count += 1; + + if (value == 1) { + while (true) { + const repeat_flag = try bit_reader.readBitsNoEof(u2, 2); + var i: usize = 0; + while (i < repeat_flag) : (i += 1) { + values[value_count] = 1; + value_count += 1; + } + if (repeat_flag < 3) break; + } + } + } + bit_reader.alignToByte(); + + if (value_count < 2) return error.MalformedFseTable; + if (accumulated_probability != total_probability) return error.MalformedFseTable; + if (value_count > expected_symbol_count) return error.MalformedFseTable; + + const table_size = total_probability; + + try buildFseTable(values[0..value_count], entries[0..table_size]); + return table_size; +} + +fn buildFseTable(values: []const u16, entries: []Table.Fse) !void { + const total_probability = @intCast(u16, entries.len); + const accuracy_log = std.math.log2_int(u16, total_probability); + assert(total_probability <= 1 << 9); + + var less_than_one_count: usize = 0; + for (values) |value, i| { + if (value == 0) { + entries[entries.len - 1 - less_than_one_count] = Table.Fse{ + .symbol = @intCast(u8, i), + .baseline = 0, + .bits = accuracy_log, + }; + less_than_one_count += 1; + } + } + + var position: usize = 0; + var temp_states: [1 << 9]u16 = undefined; + for (values) |value, symbol| { + if (value == 0 or value == 1) continue; + const probability = value - 1; + + const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch + return error.MalformedFseTable; + const share_size = @divExact(total_probability, state_share_dividend); + const double_state_count = state_share_dividend - probability; + const single_state_count = probability - double_state_count; + const share_size_log = std.math.log2_int(u16, share_size); + + var i: u16 = 0; + while (i < probability) : (i += 1) { + temp_states[i] = @intCast(u16, position); + position += (entries.len >> 1) + (entries.len >> 3) + 3; + position &= entries.len - 1; + while (position >= entries.len - less_than_one_count) { + position += (entries.len >> 1) + (entries.len >> 3) + 3; + position &= entries.len - 1; + } + } + std.sort.sort(u16, temp_states[0..probability], {}, std.sort.asc(u16)); + i = 0; + while (i < probability) : (i += 1) { + entries[temp_states[i]] = if (i < double_state_count) Table.Fse{ + .symbol = @intCast(u8, symbol), + .bits = share_size_log + 1, + .baseline = single_state_count * share_size + i * 2 * share_size, + } else Table.Fse{ + .symbol = @intCast(u8, symbol), + .bits = share_size_log, + .baseline = (i - double_state_count) * share_size, + }; + } + } +} + +test buildFseTable { + const literals_length_default_values = [36]u16{ + 5, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 2, 2, 2, 2, + 0, 0, 0, 0, + }; + + const match_lengths_default_values = [53]u16{ + 2, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, + 0, 0, 0, 0, 0, + }; + + const offset_codes_default_values = [29]u16{ + 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, + }; + + var entries: [64]Table.Fse = undefined; + try buildFseTable(&literals_length_default_values, &entries); + try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_literal_fse_table.fse, &entries); + + try buildFseTable(&match_lengths_default_values, &entries); + try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_match_fse_table.fse, &entries); + + try buildFseTable(&offset_codes_default_values, entries[0..32]); + try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_offset_fse_table.fse, entries[0..32]); +} diff --git a/lib/std/compress/zstandard/decode/huffman.zig b/lib/std/compress/zstandard/decode/huffman.zig new file mode 100644 index 0000000000..c759bfd6ab --- /dev/null +++ b/lib/std/compress/zstandard/decode/huffman.zig @@ -0,0 +1,212 @@ +const std = @import("std"); + +const types = @import("../types.zig"); +const LiteralsSection = types.compressed_block.LiteralsSection; +const Table = types.compressed_block.Table; + +const readers = @import("../readers.zig"); + +const decodeFseTable = @import("fse.zig").decodeFseTable; + +pub const Error = error{ + MalformedHuffmanTree, + MalformedFseTable, + MalformedAccuracyLog, + EndOfStream, +}; + +fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize { + var stream = std.io.limitedReader(source, compressed_size); + var bit_reader = readers.bitReader(stream.reader()); + + var entries: [1 << 6]Table.Fse = undefined; + const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { + error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, + error.EndOfStream => return error.MalformedFseTable, + }; + const accuracy_log = std.math.log2_int_ceil(usize, table_size); + + const amount = try stream.reader().readAll(buffer); + var huff_bits: readers.ReverseBitReader = undefined; + huff_bits.init(buffer[0..amount]) catch return error.MalformedHuffmanTree; + + return assignWeights(&huff_bits, accuracy_log, &entries, weights); +} + +fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize { + if (src.len < compressed_size) return error.MalformedHuffmanTree; + var stream = std.io.fixedBufferStream(src[0..compressed_size]); + var counting_reader = std.io.countingReader(stream.reader()); + var bit_reader = readers.bitReader(counting_reader.reader()); + + var entries: [1 << 6]Table.Fse = undefined; + const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { + error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, + error.EndOfStream => return error.MalformedFseTable, + }; + const accuracy_log = std.math.log2_int_ceil(usize, table_size); + + const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree; + var huff_data = src[start_index..compressed_size]; + var huff_bits: readers.ReverseBitReader = undefined; + huff_bits.init(huff_data) catch return error.MalformedHuffmanTree; + + return assignWeights(&huff_bits, accuracy_log, &entries, weights); +} + +fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize { + var i: usize = 0; + var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; + var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; + + while (i < 255) { + const even_data = entries[even_state]; + var read_bits: usize = 0; + const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable; + weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree; + i += 1; + if (read_bits < even_data.bits) { + weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree; + i += 1; + break; + } + even_state = even_data.baseline + even_bits; + + read_bits = 0; + const odd_data = entries[odd_state]; + const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable; + weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree; + i += 1; + if (read_bits < odd_data.bits) { + if (i == 256) return error.MalformedHuffmanTree; + weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree; + i += 1; + break; + } + odd_state = odd_data.baseline + odd_bits; + } else return error.MalformedHuffmanTree; + + return i + 1; // stream contains all but the last symbol +} + +fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize { + const weights_byte_count = (encoded_symbol_count + 1) / 2; + var i: usize = 0; + while (i < weights_byte_count) : (i += 1) { + const byte = try source.readByte(); + weights[2 * i] = @intCast(u4, byte >> 4); + weights[2 * i + 1] = @intCast(u4, byte & 0xF); + } + return encoded_symbol_count + 1; +} + +fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize { + for (weight_sorted_prefixed_symbols) |_, i| { + weight_sorted_prefixed_symbols[i] = .{ + .symbol = @intCast(u8, i), + .weight = undefined, + .prefix = undefined, + }; + } + + std.sort.sort( + LiteralsSection.HuffmanTree.PrefixedSymbol, + weight_sorted_prefixed_symbols, + weights, + lessThanByWeight, + ); + + var prefix: u16 = 0; + var prefixed_symbol_count: usize = 0; + var sorted_index: usize = 0; + const symbol_count = weight_sorted_prefixed_symbols.len; + while (sorted_index < symbol_count) { + var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol; + const weight = weights[symbol]; + if (weight == 0) { + sorted_index += 1; + continue; + } + + while (sorted_index < symbol_count) : ({ + sorted_index += 1; + prefixed_symbol_count += 1; + prefix += 1; + }) { + symbol = weight_sorted_prefixed_symbols[sorted_index].symbol; + if (weights[symbol] != weight) { + prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1; + break; + } + weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol; + weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix; + weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight; + } + } + return prefixed_symbol_count; +} + +fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree { + var weight_power_sum: u16 = 0; + for (weights[0 .. symbol_count - 1]) |value| { + if (value > 0) { + weight_power_sum += @as(u16, 1) << (value - 1); + } + } + + // advance to next power of two (even if weight_power_sum is a power of 2) + const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1; + const next_power_of_two = @as(u16, 1) << max_number_of_bits; + weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1; + + var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined; + const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*); + const tree = LiteralsSection.HuffmanTree{ + .max_bit_count = max_number_of_bits, + .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1), + .nodes = weight_sorted_prefixed_symbols, + }; + return tree; +} + +pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree { + const header = try source.readByte(); + var weights: [256]u4 = undefined; + const symbol_count = if (header < 128) + // FSE compressed weights + try decodeFseHuffmanTree(source, header, buffer, &weights) + else + try decodeDirectHuffmanTree(source, header - 127, &weights); + + return buildHuffmanTree(&weights, symbol_count); +} + +pub fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) Error!LiteralsSection.HuffmanTree { + if (src.len == 0) return error.MalformedHuffmanTree; + const header = src[0]; + var bytes_read: usize = 1; + var weights: [256]u4 = undefined; + const symbol_count = if (header < 128) count: { + // FSE compressed weights + bytes_read += header; + break :count try decodeFseHuffmanTreeSlice(src[1..], header, &weights); + } else count: { + var fbs = std.io.fixedBufferStream(src[1..]); + defer bytes_read += fbs.pos; + break :count try decodeDirectHuffmanTree(fbs.reader(), header - 127, &weights); + }; + + consumed_count.* += bytes_read; + return buildHuffmanTree(&weights, symbol_count); +} + +fn lessThanByWeight( + weights: [256]u4, + lhs: LiteralsSection.HuffmanTree.PrefixedSymbol, + rhs: LiteralsSection.HuffmanTree.PrefixedSymbol, +) bool { + // NOTE: this function relies on the use of a stable sorting algorithm, + // otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs; + // should be added + return weights[lhs.symbol] < weights[rhs.symbol]; +} diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 2b397007a9..abc30fda2a 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -6,8 +6,13 @@ const frame = types.frame; const LiteralsSection = types.compressed_block.LiteralsSection; const SequencesSection = types.compressed_block.SequencesSection; const Table = types.compressed_block.Table; + +pub const block = @import("decode/block.zig"); + pub const RingBuffer = @import("RingBuffer.zig"); +const readers = @import("readers.zig"); + const readInt = std.mem.readIntLittle; const readIntSlice = std.mem.readIntSliceLittle; fn readVarInt(comptime T: type, bytes: []const u8) T { @@ -61,526 +66,6 @@ pub fn decodeFrame( }; } -pub const DecodeState = struct { - repeat_offsets: [3]u32, - - offset: StateData(8), - match: StateData(9), - literal: StateData(9), - - offset_fse_buffer: []Table.Fse, - match_fse_buffer: []Table.Fse, - literal_fse_buffer: []Table.Fse, - - fse_tables_undefined: bool, - - literal_stream_reader: ReverseBitReader, - literal_stream_index: usize, - literal_streams: LiteralsSection.Streams, - literal_header: LiteralsSection.Header, - huffman_tree: ?LiteralsSection.HuffmanTree, - - literal_written_count: usize, - - fn StateData(comptime max_accuracy_log: comptime_int) type { - return struct { - state: State, - table: Table, - accuracy_log: u8, - - const State = std.meta.Int(.unsigned, max_accuracy_log); - }; - } - - pub fn init( - literal_fse_buffer: []Table.Fse, - match_fse_buffer: []Table.Fse, - offset_fse_buffer: []Table.Fse, - ) DecodeState { - return DecodeState{ - .repeat_offsets = .{ - types.compressed_block.start_repeated_offset_1, - types.compressed_block.start_repeated_offset_2, - types.compressed_block.start_repeated_offset_3, - }, - - .offset = undefined, - .match = undefined, - .literal = undefined, - - .literal_fse_buffer = literal_fse_buffer, - .match_fse_buffer = match_fse_buffer, - .offset_fse_buffer = offset_fse_buffer, - - .fse_tables_undefined = true, - - .literal_written_count = 0, - .literal_header = undefined, - .literal_streams = undefined, - .literal_stream_reader = undefined, - .literal_stream_index = undefined, - .huffman_tree = null, - }; - } - - /// Prepare the decoder to decode a compressed block. Loads the literals - /// stream and Huffman tree from `literals` and reads the FSE tables from - /// `source`. - /// - /// Errors: - /// - returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's - /// first byte does not have any bits set. - /// - returns `error.TreelessLiteralsFirst` `literals` is a treeless literals section - /// and the decode state does not have a Huffman tree from a previous block. - pub fn prepare( - self: *DecodeState, - source: anytype, - literals: LiteralsSection, - sequences_header: SequencesSection.Header, - ) !void { - self.literal_written_count = 0; - self.literal_header = literals.header; - self.literal_streams = literals.streams; - - if (literals.huffman_tree) |tree| { - self.huffman_tree = tree; - } else if (literals.header.block_type == .treeless and self.huffman_tree == null) { - return error.TreelessLiteralsFirst; - } - - switch (literals.header.block_type) { - .raw, .rle => {}, - .compressed, .treeless => { - self.literal_stream_index = 0; - switch (literals.streams) { - .one => |slice| try self.initLiteralStream(slice), - .four => |streams| try self.initLiteralStream(streams[0]), - } - }, - } - - if (sequences_header.sequence_count > 0) { - try self.updateFseTable(source, .literal, sequences_header.literal_lengths); - try self.updateFseTable(source, .offset, sequences_header.offsets); - try self.updateFseTable(source, .match, sequences_header.match_lengths); - self.fse_tables_undefined = false; - } - } - - /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream` - /// if `bit_reader` does not contain enough bits. - pub fn readInitialFseState(self: *DecodeState, bit_reader: *ReverseBitReader) error{EndOfStream}!void { - self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log); - self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log); - self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log); - } - - fn updateRepeatOffset(self: *DecodeState, offset: u32) void { - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]); - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]); - self.repeat_offsets[0] = offset; - } - - fn useRepeatOffset(self: *DecodeState, index: usize) u32 { - if (index == 1) - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]) - else if (index == 2) { - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]); - std.mem.swap(u32, &self.repeat_offsets[1], &self.repeat_offsets[2]); - } - return self.repeat_offsets[0]; - } - - const DataType = enum { offset, match, literal }; - - fn updateState( - self: *DecodeState, - comptime choice: DataType, - bit_reader: *ReverseBitReader, - ) error{ MalformedFseBits, EndOfStream }!void { - switch (@field(self, @tagName(choice)).table) { - .rle => {}, - .fse => |table| { - const data = table[@field(self, @tagName(choice)).state]; - const T = @TypeOf(@field(self, @tagName(choice))).State; - const bits_summand = try bit_reader.readBitsNoEof(T, data.bits); - const next_state = std.math.cast( - @TypeOf(@field(self, @tagName(choice))).State, - data.baseline + bits_summand, - ) orelse return error.MalformedFseBits; - @field(self, @tagName(choice)).state = next_state; - }, - } - } - - const FseTableError = error{ - MalformedFseTable, - MalformedAccuracyLog, - RepeatModeFirst, - EndOfStream, - }; - - fn updateFseTable( - self: *DecodeState, - source: anytype, - comptime choice: DataType, - mode: SequencesSection.Header.Mode, - ) !void { - const field_name = @tagName(choice); - switch (mode) { - .predefined => { - @field(self, field_name).accuracy_log = - @field(types.compressed_block.default_accuracy_log, field_name); - - @field(self, field_name).table = - @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table"); - }, - .rle => { - @field(self, field_name).accuracy_log = 0; - @field(self, field_name).table = .{ .rle = try source.readByte() }; - }, - .fse => { - var bit_reader = bitReader(source); - - const table_size = try decodeFseTable( - &bit_reader, - @field(types.compressed_block.table_symbol_count_max, field_name), - @field(types.compressed_block.table_accuracy_log_max, field_name), - @field(self, field_name ++ "_fse_buffer"), - ); - @field(self, field_name).table = .{ - .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size], - }; - @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size); - }, - .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst, - } - } - - const Sequence = struct { - literal_length: u32, - match_length: u32, - offset: u32, - }; - - fn nextSequence( - self: *DecodeState, - bit_reader: *ReverseBitReader, - ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence { - const raw_code = self.getCode(.offset); - const offset_code = std.math.cast(u5, raw_code) orelse { - return error.OffsetCodeTooLarge; - }; - const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code); - - const match_code = self.getCode(.match); - const match = types.compressed_block.match_length_code_table[match_code]; - const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]); - - const literal_code = self.getCode(.literal); - const literal = types.compressed_block.literals_length_code_table[literal_code]; - const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]); - - const offset = if (offset_value > 3) offset: { - const offset = offset_value - 3; - self.updateRepeatOffset(offset); - break :offset offset; - } else offset: { - if (literal_length == 0) { - if (offset_value == 3) { - const offset = self.repeat_offsets[0] - 1; - self.updateRepeatOffset(offset); - break :offset offset; - } - break :offset self.useRepeatOffset(offset_value); - } - break :offset self.useRepeatOffset(offset_value - 1); - }; - - return .{ - .literal_length = literal_length, - .match_length = match_length, - .offset = offset, - }; - } - - fn executeSequenceSlice( - self: *DecodeState, - dest: []u8, - write_pos: usize, - sequence: Sequence, - ) (error{MalformedSequence} || DecodeLiteralsError)!void { - if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence; - - try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length); - const copy_start = write_pos + sequence.literal_length - sequence.offset; - const copy_end = copy_start + sequence.match_length; - // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr - // to allow repeats - std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]); - } - - fn executeSequenceRingBuffer( - self: *DecodeState, - dest: *RingBuffer, - sequence: Sequence, - ) (error{MalformedSequence} || DecodeLiteralsError)!void { - if (sequence.offset > dest.data.len) return error.MalformedSequence; - - try self.decodeLiteralsRingBuffer(dest, sequence.literal_length); - const copy_start = dest.write_index + dest.data.len - sequence.offset; - const copy_slice = dest.sliceAt(copy_start, sequence.match_length); - // TODO: would std.mem.copy and figuring out dest slice be better/faster? - for (copy_slice.first) |b| dest.writeAssumeCapacity(b); - for (copy_slice.second) |b| dest.writeAssumeCapacity(b); - } - - const DecodeSequenceError = error{ - OffsetCodeTooLarge, - EndOfStream, - MalformedSequence, - MalformedFseBits, - } || DecodeLiteralsError; - - /// Decode one sequence from `bit_reader` into `dest`, written starting at - /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns - /// `error.MalformedSequence` error if the decompressed sequence would be longer - /// than `sequence_size_limit` or the sequence's offset is too large; returns - /// `error.EndOfStream` if `bit_reader` does not contain enough bits; returns - /// `error.UnexpectedEndOfLiteralStream` if the decoder state's literal streams - /// do not contain enough literals for the sequence (this may mean the literal - /// stream or the sequence is malformed). - pub fn decodeSequenceSlice( - self: *DecodeState, - dest: []u8, - write_pos: usize, - bit_reader: *ReverseBitReader, - sequence_size_limit: usize, - last_sequence: bool, - ) DecodeSequenceError!usize { - const sequence = try self.nextSequence(bit_reader); - const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; - if (sequence_length > sequence_size_limit) return error.MalformedSequence; - - try self.executeSequenceSlice(dest, write_pos, sequence); - if (!last_sequence) { - try self.updateState(.literal, bit_reader); - try self.updateState(.match, bit_reader); - try self.updateState(.offset, bit_reader); - } - return sequence_length; - } - - /// Decode one sequence from `bit_reader` into `dest`; see `decodeSequenceSlice`. - pub fn decodeSequenceRingBuffer( - self: *DecodeState, - dest: *RingBuffer, - bit_reader: anytype, - sequence_size_limit: usize, - last_sequence: bool, - ) DecodeSequenceError!usize { - const sequence = try self.nextSequence(bit_reader); - const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; - if (sequence_length > sequence_size_limit) return error.MalformedSequence; - - try self.executeSequenceRingBuffer(dest, sequence); - if (!last_sequence) { - try self.updateState(.literal, bit_reader); - try self.updateState(.match, bit_reader); - try self.updateState(.offset, bit_reader); - } - return sequence_length; - } - - fn nextLiteralMultiStream( - self: *DecodeState, - ) error{BitStreamHasNoStartBit}!void { - self.literal_stream_index += 1; - try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]); - } - - pub fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void { - try self.literal_stream_reader.init(bytes); - } - - const LiteralBitsError = error{ - BitStreamHasNoStartBit, - UnexpectedEndOfLiteralStream, - }; - fn readLiteralsBits( - self: *DecodeState, - comptime T: type, - bit_count_to_read: usize, - ) LiteralBitsError!T { - return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: { - if (self.literal_streams == .four and self.literal_stream_index < 3) { - try self.nextLiteralMultiStream(); - break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch - return error.UnexpectedEndOfLiteralStream; - } else { - return error.UnexpectedEndOfLiteralStream; - } - }; - } - - const DecodeLiteralsError = error{ - MalformedLiteralsLength, - PrefixNotFound, - } || LiteralBitsError; - - /// Decode `len` bytes of literals into `dest`. `literals` should be the - /// `LiteralsSection` that was passed to `prepare()`. Returns - /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by - /// `self` plus `len` is greater than the regenerated size of `literals`. - /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if - /// there are problems decoding Huffman compressed literals. - pub fn decodeLiteralsSlice( - self: *DecodeState, - dest: []u8, - len: usize, - ) DecodeLiteralsError!void { - if (self.literal_written_count + len > self.literal_header.regenerated_size) - return error.MalformedLiteralsLength; - - switch (self.literal_header.block_type) { - .raw => { - const literals_end = self.literal_written_count + len; - const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; - std.mem.copy(u8, dest, literal_data); - self.literal_written_count += len; - }, - .rle => { - var i: usize = 0; - while (i < len) : (i += 1) { - dest[i] = self.literal_streams.one[0]; - } - self.literal_written_count += len; - }, - .compressed, .treeless => { - // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; - const huffman_tree = self.huffman_tree orelse unreachable; - const max_bit_count = huffman_tree.max_bit_count; - const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( - huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, - max_bit_count, - ); - var bits_read: u4 = 0; - var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one; - var bit_count_to_read: u4 = starting_bit_count; - var i: usize = 0; - while (i < len) : (i += 1) { - var prefix: u16 = 0; - while (true) { - const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| { - return err; - }; - prefix <<= bit_count_to_read; - prefix |= new_bits; - bits_read += bit_count_to_read; - const result = huffman_tree.query(huffman_tree_index, prefix) catch |err| { - return err; - }; - - switch (result) { - .symbol => |sym| { - dest[i] = sym; - bit_count_to_read = starting_bit_count; - bits_read = 0; - huffman_tree_index = huffman_tree.symbol_count_minus_one; - break; - }, - .index => |index| { - huffman_tree_index = index; - const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( - huffman_tree.nodes[index].weight, - max_bit_count, - ); - bit_count_to_read = bit_count - bits_read; - }, - } - } - } - self.literal_written_count += len; - }, - } - } - - /// Decode literals into `dest`; see `decodeLiteralsSlice()`. - pub fn decodeLiteralsRingBuffer( - self: *DecodeState, - dest: *RingBuffer, - len: usize, - ) DecodeLiteralsError!void { - if (self.literal_written_count + len > self.literal_header.regenerated_size) - return error.MalformedLiteralsLength; - - switch (self.literal_header.block_type) { - .raw => { - const literals_end = self.literal_written_count + len; - const literal_data = self.literal_streams.one[self.literal_written_count..literals_end]; - dest.writeSliceAssumeCapacity(literal_data); - self.literal_written_count += len; - }, - .rle => { - var i: usize = 0; - while (i < len) : (i += 1) { - dest.writeAssumeCapacity(self.literal_streams.one[0]); - } - self.literal_written_count += len; - }, - .compressed, .treeless => { - // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; - const huffman_tree = self.huffman_tree orelse unreachable; - const max_bit_count = huffman_tree.max_bit_count; - const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( - huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, - max_bit_count, - ); - var bits_read: u4 = 0; - var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one; - var bit_count_to_read: u4 = starting_bit_count; - var i: usize = 0; - while (i < len) : (i += 1) { - var prefix: u16 = 0; - while (true) { - const new_bits = try self.readLiteralsBits(u16, bit_count_to_read); - prefix <<= bit_count_to_read; - prefix |= new_bits; - bits_read += bit_count_to_read; - const result = try huffman_tree.query(huffman_tree_index, prefix); - - switch (result) { - .symbol => |sym| { - dest.writeAssumeCapacity(sym); - bit_count_to_read = starting_bit_count; - bits_read = 0; - huffman_tree_index = huffman_tree.symbol_count_minus_one; - break; - }, - .index => |index| { - huffman_tree_index = index; - const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( - huffman_tree.nodes[index].weight, - max_bit_count, - ); - bit_count_to_read = bit_count - bits_read; - }, - } - } - } - self.literal_written_count += len; - }, - } - } - - fn getCode(self: *DecodeState, comptime choice: DataType) u32 { - return switch (@field(self, @tagName(choice)).table) { - .rle => |value| value, - .fse => |table| table[@field(self, @tagName(choice)).state].symbol, - }; - } -}; - pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 { const hash = hasher.final(); return @intCast(u32, hash & 0xFFFFFFFF); @@ -589,7 +74,7 @@ pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 { const FrameError = error{ DictionaryIdFlagUnsupported, ChecksumFailure, -} || InvalidBit || DecodeBlockError; +} || InvalidBit || block.Error; /// Decode a Zstandard frame from `src` into `dest`, returning the number of /// bytes read from `src` and written to `dest`; if the frame does not declare @@ -695,15 +180,15 @@ pub fn decodeZStandardFrameAlloc( var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined; var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined; - var block_header = try decodeBlockHeaderSlice(src[consumed_count..]); + var block_header = try block.decodeBlockHeaderSlice(src[consumed_count..]); consumed_count += 3; - var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); + var decode_state = block.DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); while (true) : ({ - block_header = try decodeBlockHeaderSlice(src[consumed_count..]); + block_header = try block.decodeBlockHeaderSlice(src[consumed_count..]); consumed_count += 3; }) { if (block_header.block_size > frame_context.block_size_max) return error.BlockSizeOverMaximum; - const written_size = try decodeBlockRingBuffer( + const written_size = try block.decodeBlockRingBuffer( &ring_buffer, src[consumed_count..], block_header, @@ -731,37 +216,28 @@ pub fn decodeZStandardFrameAlloc( return result.toOwnedSlice(); } -const DecodeBlockError = error{ - BlockSizeOverMaximum, - MalformedBlockSize, - ReservedBlock, - MalformedRleBlock, - MalformedCompressedBlock, - EndOfStream, -}; - /// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`. -pub fn decodeFrameBlocks( +fn decodeFrameBlocks( dest: []u8, src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64, -) DecodeBlockError!usize { +) block.Error!usize { // These tables take 7680 bytes var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined; var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined; var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined; - var block_header = try decodeBlockHeaderSlice(src); + var block_header = try block.decodeBlockHeaderSlice(src); var bytes_read: usize = 3; defer consumed_count.* += bytes_read; - var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); + var decode_state = block.DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); var written_count: usize = 0; while (true) : ({ - block_header = try decodeBlockHeaderSlice(src[bytes_read..]); + block_header = try block.decodeBlockHeaderSlice(src[bytes_read..]); bytes_read += 3; }) { - const written_size = try decodeBlock( + const written_size = try block.decodeBlock( dest, src[bytes_read..], block_header, @@ -776,255 +252,6 @@ pub fn decodeFrameBlocks( return written_count; } -/// Decode a single block from `src` into `dest`. The beginning of `src` should -/// be the start of the block content (i.e. directly after the block header). -/// Increments `consumed_count` by the number of bytes read from `src` to decode -/// the block and returns the decompressed size of the block. -pub fn decodeBlock( - dest: []u8, - src: []const u8, - block_header: frame.ZStandard.Block.Header, - decode_state: *DecodeState, - consumed_count: *usize, - written_count: usize, -) DecodeBlockError!usize { - const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB - const block_size = block_header.block_size; - if (block_size_max < block_size) return error.BlockSizeOverMaximum; - switch (block_header.block_type) { - .raw => { - if (src.len < block_size) return error.MalformedBlockSize; - const data = src[0..block_size]; - std.mem.copy(u8, dest[written_count..], data); - consumed_count.* += block_size; - return block_size; - }, - .rle => { - if (src.len < 1) return error.MalformedRleBlock; - var write_pos: usize = written_count; - while (write_pos < block_size + written_count) : (write_pos += 1) { - dest[write_pos] = src[0]; - } - consumed_count.* += 1; - return block_size; - }, - .compressed => { - if (src.len < block_size) return error.MalformedBlockSize; - var bytes_read: usize = 0; - const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch - return error.MalformedCompressedBlock; - var fbs = std.io.fixedBufferStream(src[bytes_read..]); - const fbs_reader = fbs.reader(); - const sequences_header = decodeSequencesHeader(fbs_reader) catch - return error.MalformedCompressedBlock; - - decode_state.prepare(fbs_reader, literals, sequences_header) catch - return error.MalformedCompressedBlock; - - bytes_read += fbs.pos; - - var bytes_written: usize = 0; - if (sequences_header.sequence_count > 0) { - const bit_stream_bytes = src[bytes_read..block_size]; - var bit_stream: ReverseBitReader = undefined; - bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock; - - decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; - - var sequence_size_limit = block_size_max; - var i: usize = 0; - while (i < sequences_header.sequence_count) : (i += 1) { - const write_pos = written_count + bytes_written; - const decompressed_size = decode_state.decodeSequenceSlice( - dest, - write_pos, - &bit_stream, - sequence_size_limit, - i == sequences_header.sequence_count - 1, - ) catch return error.MalformedCompressedBlock; - bytes_written += decompressed_size; - sequence_size_limit -= decompressed_size; - } - - bytes_read += bit_stream_bytes.len; - } - if (bytes_read != block_size) return error.MalformedCompressedBlock; - - if (decode_state.literal_written_count < literals.header.regenerated_size) { - const len = literals.header.regenerated_size - decode_state.literal_written_count; - decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch - return error.MalformedCompressedBlock; - bytes_written += len; - } - - consumed_count.* += bytes_read; - return bytes_written; - }, - .reserved => return error.ReservedBlock, - } -} - -/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns -/// the size of the decompressed block, which can be used with `dest.sliceLast()` -/// to get the decompressed bytes. -pub fn decodeBlockRingBuffer( - dest: *RingBuffer, - src: []const u8, - block_header: frame.ZStandard.Block.Header, - decode_state: *DecodeState, - consumed_count: *usize, - block_size_max: usize, -) DecodeBlockError!usize { - const block_size = block_header.block_size; - if (block_size_max < block_size) return error.BlockSizeOverMaximum; - switch (block_header.block_type) { - .raw => { - if (src.len < block_size) return error.MalformedBlockSize; - const data = src[0..block_size]; - dest.writeSliceAssumeCapacity(data); - consumed_count.* += block_size; - return block_size; - }, - .rle => { - if (src.len < 1) return error.MalformedRleBlock; - var write_pos: usize = 0; - while (write_pos < block_size) : (write_pos += 1) { - dest.writeAssumeCapacity(src[0]); - } - consumed_count.* += 1; - return block_size; - }, - .compressed => { - if (src.len < block_size) return error.MalformedBlockSize; - var bytes_read: usize = 0; - const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch - return error.MalformedCompressedBlock; - var fbs = std.io.fixedBufferStream(src[bytes_read..]); - const fbs_reader = fbs.reader(); - const sequences_header = decodeSequencesHeader(fbs_reader) catch - return error.MalformedCompressedBlock; - - decode_state.prepare(fbs_reader, literals, sequences_header) catch - return error.MalformedCompressedBlock; - - bytes_read += fbs.pos; - - var bytes_written: usize = 0; - if (sequences_header.sequence_count > 0) { - const bit_stream_bytes = src[bytes_read..block_size]; - var bit_stream: ReverseBitReader = undefined; - bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock; - - decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; - - var sequence_size_limit = block_size_max; - var i: usize = 0; - while (i < sequences_header.sequence_count) : (i += 1) { - const decompressed_size = decode_state.decodeSequenceRingBuffer( - dest, - &bit_stream, - sequence_size_limit, - i == sequences_header.sequence_count - 1, - ) catch return error.MalformedCompressedBlock; - bytes_written += decompressed_size; - sequence_size_limit -= decompressed_size; - } - - bytes_read += bit_stream_bytes.len; - } - if (bytes_read != block_size) return error.MalformedCompressedBlock; - - if (decode_state.literal_written_count < literals.header.regenerated_size) { - const len = literals.header.regenerated_size - decode_state.literal_written_count; - decode_state.decodeLiteralsRingBuffer(dest, len) catch - return error.MalformedCompressedBlock; - bytes_written += len; - } - - consumed_count.* += bytes_read; - if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; - return bytes_written; - }, - .reserved => return error.ReservedBlock, - } -} - -/// Decode a single block from `source` into `dest`. Literal and sequence data -/// from the block is copied into `literals_buffer` and `sequence_buffer`, which -/// must be large enough or `error.LiteralsBufferTooSmall` and -/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an -/// upper bound for the size of both buffers). See `decodeBlock` -/// and `decodeBlockRingBuffer` for function that can decode a block without -/// these extra copies. -pub fn decodeBlockReader( - dest: *RingBuffer, - source: anytype, - block_header: frame.ZStandard.Block.Header, - decode_state: *DecodeState, - block_size_max: usize, - literals_buffer: []u8, - sequence_buffer: []u8, -) !void { - const block_size = block_header.block_size; - var block_reader_limited = std.io.limitedReader(source, block_size); - const block_reader = block_reader_limited.reader(); - if (block_size_max < block_size) return error.BlockSizeOverMaximum; - switch (block_header.block_type) { - .raw => { - const slice = dest.sliceAt(dest.write_index, block_size); - try source.readNoEof(slice.first); - try source.readNoEof(slice.second); - dest.write_index = dest.mask2(dest.write_index + block_size); - }, - .rle => { - const byte = try source.readByte(); - var i: usize = 0; - while (i < block_size) : (i += 1) { - dest.writeAssumeCapacity(byte); - } - }, - .compressed => { - const literals = try decodeLiteralsSection(block_reader, literals_buffer); - const sequences_header = try decodeSequencesHeader(block_reader); - - try decode_state.prepare(block_reader, literals, sequences_header); - - if (sequences_header.sequence_count > 0) { - if (sequence_buffer.len < block_reader_limited.bytes_left) - return error.SequenceBufferTooSmall; - - const size = try block_reader.readAll(sequence_buffer); - var bit_stream: ReverseBitReader = undefined; - try bit_stream.init(sequence_buffer[0..size]); - - decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; - - var sequence_size_limit = block_size_max; - var i: usize = 0; - while (i < sequences_header.sequence_count) : (i += 1) { - const decompressed_size = decode_state.decodeSequenceRingBuffer( - dest, - &bit_stream, - sequence_size_limit, - i == sequences_header.sequence_count - 1, - ) catch return error.MalformedCompressedBlock; - sequence_size_limit -= decompressed_size; - } - } - - if (decode_state.literal_written_count < literals.header.regenerated_size) { - const len = literals.header.regenerated_size - decode_state.literal_written_count; - decode_state.decodeLiteralsRingBuffer(dest, len) catch - return error.MalformedCompressedBlock; - } - - decode_state.literal_written_count = 0; - assert(block_reader.readByte() == error.EndOfStream); - }, - .reserved => return error.ReservedBlock, - } -} - /// Decode the header of a skippable frame. pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header { const magic = readInt(u32, src[0..4]); @@ -1090,673 +317,6 @@ pub fn decodeZStandardHeader(source: anytype) (error{EndOfStream} || InvalidBit) return header; } -/// Decode the header of a block. -pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { - const last_block = src[0] & 1 == 1; - const block_type = @intToEnum(frame.ZStandard.Block.Type, (src[0] & 0b110) >> 1); - const block_size = ((src[0] & 0b11111000) >> 3) + (@as(u21, src[1]) << 5) + (@as(u21, src[2]) << 13); - return .{ - .last_block = last_block, - .block_type = block_type, - .block_size = block_size, - }; -} - -pub fn decodeBlockHeaderSlice(src: []const u8) error{EndOfStream}!frame.ZStandard.Block.Header { - if (src.len < 3) return error.EndOfStream; - return decodeBlockHeader(src[0..3]); -} - -/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the -/// number of bytes the section uses. -/// -/// Errors: -/// - returns `error.MalformedLiteralsHeader` if the header is invalid -/// - returns `error.MalformedLiteralsSection` if there are errors decoding -pub fn decodeLiteralsSectionSlice( - src: []const u8, - consumed_count: *usize, -) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || DecodeHuffmanError)!LiteralsSection { - var bytes_read: usize = 0; - const header = header: { - var fbs = std.io.fixedBufferStream(src); - defer bytes_read = fbs.pos; - break :header decodeLiteralsHeader(fbs.reader()) catch return error.MalformedLiteralsHeader; - }; - switch (header.block_type) { - .raw => { - if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection; - const stream = src[bytes_read .. bytes_read + header.regenerated_size]; - consumed_count.* += header.regenerated_size + bytes_read; - return LiteralsSection{ - .header = header, - .huffman_tree = null, - .streams = .{ .one = stream }, - }; - }, - .rle => { - if (src.len < bytes_read + 1) return error.MalformedLiteralsSection; - const stream = src[bytes_read .. bytes_read + 1]; - consumed_count.* += 1 + bytes_read; - return LiteralsSection{ - .header = header, - .huffman_tree = null, - .streams = .{ .one = stream }, - }; - }, - .compressed, .treeless => { - const huffman_tree_start = bytes_read; - const huffman_tree = if (header.block_type == .compressed) - try decodeHuffmanTreeSlice(src[bytes_read..], &bytes_read) - else - null; - const huffman_tree_size = bytes_read - huffman_tree_start; - const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size; - - if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection; - const stream_data = src[bytes_read .. bytes_read + total_streams_size]; - - const streams = try decodeStreams(header.size_format, stream_data); - consumed_count.* += bytes_read + total_streams_size; - return LiteralsSection{ - .header = header, - .huffman_tree = huffman_tree, - .streams = streams, - }; - }, - } -} - -/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the -/// number of bytes the section uses. -/// -/// Errors: -/// - returns `error.MalformedLiteralsHeader` if the header is invalid -/// - returns `error.MalformedLiteralsSection` if there are errors decoding -pub fn decodeLiteralsSection( - source: anytype, - buffer: []u8, -) !LiteralsSection { - const header = try decodeLiteralsHeader(source); - switch (header.block_type) { - .raw => { - try source.readNoEof(buffer[0..header.regenerated_size]); - return LiteralsSection{ - .header = header, - .huffman_tree = null, - .streams = .{ .one = buffer }, - }; - }, - .rle => { - buffer[0] = try source.readByte(); - return LiteralsSection{ - .header = header, - .huffman_tree = null, - .streams = .{ .one = buffer[0..1] }, - }; - }, - .compressed, .treeless => { - var counting_reader = std.io.countingReader(source); - const huffman_tree = if (header.block_type == .compressed) - try decodeHuffmanTree(counting_reader.reader(), buffer) - else - null; - const huffman_tree_size = counting_reader.bytes_read; - const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size); - - if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall; - try source.readNoEof(buffer[0..total_streams_size]); - const stream_data = buffer[0..total_streams_size]; - - const streams = try decodeStreams(header.size_format, stream_data); - return LiteralsSection{ - .header = header, - .huffman_tree = huffman_tree, - .streams = streams, - }; - }, - } -} - -fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams { - if (size_format == 0) { - return .{ .one = stream_data }; - } - - if (stream_data.len < 6) return error.MalformedLiteralsSection; - - const stream_1_length = @as(usize, readInt(u16, stream_data[0..2])); - const stream_2_length = @as(usize, readInt(u16, stream_data[2..4])); - const stream_3_length = @as(usize, readInt(u16, stream_data[4..6])); - - const stream_1_start = 6; - const stream_2_start = stream_1_start + stream_1_length; - const stream_3_start = stream_2_start + stream_2_length; - const stream_4_start = stream_3_start + stream_3_length; - - return .{ .four = .{ - stream_data[stream_1_start .. stream_1_start + stream_1_length], - stream_data[stream_2_start .. stream_2_start + stream_2_length], - stream_data[stream_3_start .. stream_3_start + stream_3_length], - stream_data[stream_4_start..], - } }; -} - -const DecodeHuffmanError = error{ - MalformedHuffmanTree, - MalformedFseTable, - MalformedAccuracyLog, -}; - -fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize { - var stream = std.io.limitedReader(source, compressed_size); - var bit_reader = bitReader(stream.reader()); - - var entries: [1 << 6]Table.Fse = undefined; - const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { - error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, - error.EndOfStream => return error.MalformedFseTable, - }; - const accuracy_log = std.math.log2_int_ceil(usize, table_size); - - const amount = try stream.reader().readAll(buffer); - var huff_bits: ReverseBitReader = undefined; - huff_bits.init(buffer[0..amount]) catch return error.MalformedHuffmanTree; - - return assignWeights(&huff_bits, accuracy_log, &entries, weights); -} - -fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize { - if (src.len < compressed_size) return error.MalformedHuffmanTree; - var stream = std.io.fixedBufferStream(src[0..compressed_size]); - var counting_reader = std.io.countingReader(stream.reader()); - var bit_reader = bitReader(counting_reader.reader()); - - var entries: [1 << 6]Table.Fse = undefined; - const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { - error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, - error.EndOfStream => return error.MalformedFseTable, - }; - const accuracy_log = std.math.log2_int_ceil(usize, table_size); - - const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree; - var huff_data = src[start_index..compressed_size]; - var huff_bits: ReverseBitReader = undefined; - huff_bits.init(huff_data) catch return error.MalformedHuffmanTree; - - return assignWeights(&huff_bits, accuracy_log, &entries, weights); -} - -fn assignWeights(huff_bits: *ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize { - var i: usize = 0; - var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; - var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; - - while (i < 255) { - const even_data = entries[even_state]; - var read_bits: usize = 0; - const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable; - weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree; - i += 1; - if (read_bits < even_data.bits) { - weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree; - i += 1; - break; - } - even_state = even_data.baseline + even_bits; - - read_bits = 0; - const odd_data = entries[odd_state]; - const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable; - weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree; - i += 1; - if (read_bits < odd_data.bits) { - if (i == 256) return error.MalformedHuffmanTree; - weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree; - i += 1; - break; - } - odd_state = odd_data.baseline + odd_bits; - } else return error.MalformedHuffmanTree; - - return i + 1; // stream contains all but the last symbol -} - -fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize { - const weights_byte_count = (encoded_symbol_count + 1) / 2; - var i: usize = 0; - while (i < weights_byte_count) : (i += 1) { - const byte = try source.readByte(); - weights[2 * i] = @intCast(u4, byte >> 4); - weights[2 * i + 1] = @intCast(u4, byte & 0xF); - } - return encoded_symbol_count + 1; -} - -fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize { - for (weight_sorted_prefixed_symbols) |_, i| { - weight_sorted_prefixed_symbols[i] = .{ - .symbol = @intCast(u8, i), - .weight = undefined, - .prefix = undefined, - }; - } - - std.sort.sort( - LiteralsSection.HuffmanTree.PrefixedSymbol, - weight_sorted_prefixed_symbols, - weights, - lessThanByWeight, - ); - - var prefix: u16 = 0; - var prefixed_symbol_count: usize = 0; - var sorted_index: usize = 0; - const symbol_count = weight_sorted_prefixed_symbols.len; - while (sorted_index < symbol_count) { - var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol; - const weight = weights[symbol]; - if (weight == 0) { - sorted_index += 1; - continue; - } - - while (sorted_index < symbol_count) : ({ - sorted_index += 1; - prefixed_symbol_count += 1; - prefix += 1; - }) { - symbol = weight_sorted_prefixed_symbols[sorted_index].symbol; - if (weights[symbol] != weight) { - prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1; - break; - } - weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol; - weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix; - weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight; - } - } - return prefixed_symbol_count; -} - -fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree { - var weight_power_sum: u16 = 0; - for (weights[0 .. symbol_count - 1]) |value| { - if (value > 0) { - weight_power_sum += @as(u16, 1) << (value - 1); - } - } - - // advance to next power of two (even if weight_power_sum is a power of 2) - const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1; - const next_power_of_two = @as(u16, 1) << max_number_of_bits; - weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1; - - var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined; - const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*); - const tree = LiteralsSection.HuffmanTree{ - .max_bit_count = max_number_of_bits, - .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1), - .nodes = weight_sorted_prefixed_symbols, - }; - return tree; -} - -fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree { - const header = try source.readByte(); - var weights: [256]u4 = undefined; - const symbol_count = if (header < 128) - // FSE compressed weights - try decodeFseHuffmanTree(source, header, buffer, &weights) - else - try decodeDirectHuffmanTree(source, header - 127, &weights); - - return buildHuffmanTree(&weights, symbol_count); -} - -fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) (error{EndOfStream} || DecodeHuffmanError)!LiteralsSection.HuffmanTree { - if (src.len == 0) return error.MalformedHuffmanTree; - const header = src[0]; - var bytes_read: usize = 1; - var weights: [256]u4 = undefined; - const symbol_count = if (header < 128) count: { - // FSE compressed weights - bytes_read += header; - break :count try decodeFseHuffmanTreeSlice(src[1..], header, &weights); - } else count: { - var fbs = std.io.fixedBufferStream(src[1..]); - defer bytes_read += fbs.pos; - break :count try decodeDirectHuffmanTree(fbs.reader(), header - 127, &weights); - }; - - consumed_count.* += bytes_read; - return buildHuffmanTree(&weights, symbol_count); -} - -fn lessThanByWeight( - weights: [256]u4, - lhs: LiteralsSection.HuffmanTree.PrefixedSymbol, - rhs: LiteralsSection.HuffmanTree.PrefixedSymbol, -) bool { - // NOTE: this function relies on the use of a stable sorting algorithm, - // otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs; - // should be added - return weights[lhs.symbol] < weights[rhs.symbol]; -} - -/// Decode a literals section header. -pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header { - const byte0 = try source.readByte(); - const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11); - const size_format = @intCast(u2, (byte0 & 0b1100) >> 2); - var regenerated_size: u20 = undefined; - var compressed_size: ?u18 = null; - switch (block_type) { - .raw, .rle => { - switch (size_format) { - 0, 2 => { - regenerated_size = byte0 >> 3; - }, - 1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4), - 3 => regenerated_size = (byte0 >> 4) + - (@as(u20, try source.readByte()) << 4) + - (@as(u20, try source.readByte()) << 12), - } - }, - .compressed, .treeless => { - const byte1 = try source.readByte(); - const byte2 = try source.readByte(); - switch (size_format) { - 0, 1 => { - regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4); - compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2); - }, - 2 => { - const byte3 = try source.readByte(); - regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12); - compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6); - }, - 3 => { - const byte3 = try source.readByte(); - const byte4 = try source.readByte(); - regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12); - compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10); - }, - } - }, - } - return LiteralsSection.Header{ - .block_type = block_type, - .size_format = size_format, - .regenerated_size = regenerated_size, - .compressed_size = compressed_size, - }; -} - -/// Decode a sequences section header. -/// -/// Errors: -/// - returns `error.ReservedBitSet` is the reserved bit is set -/// - returns `error.MalformedSequencesHeader` if the header is invalid -pub fn decodeSequencesHeader( - source: anytype, -) !SequencesSection.Header { - var sequence_count: u24 = undefined; - - const byte0 = try source.readByte(); - if (byte0 == 0) { - return SequencesSection.Header{ - .sequence_count = 0, - .offsets = undefined, - .match_lengths = undefined, - .literal_lengths = undefined, - }; - } else if (byte0 < 128) { - sequence_count = byte0; - } else if (byte0 < 255) { - sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte(); - } else { - sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00; - } - - const compression_modes = try source.readByte(); - - const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2); - const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4); - const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6); - if (compression_modes & 0b11 != 0) return error.ReservedBitSet; - - return SequencesSection.Header{ - .sequence_count = sequence_count, - .offsets = offsets_mode, - .match_lengths = matches_mode, - .literal_lengths = literal_mode, - }; -} - -fn buildFseTable(values: []const u16, entries: []Table.Fse) !void { - const total_probability = @intCast(u16, entries.len); - const accuracy_log = std.math.log2_int(u16, total_probability); - assert(total_probability <= 1 << 9); - - var less_than_one_count: usize = 0; - for (values) |value, i| { - if (value == 0) { - entries[entries.len - 1 - less_than_one_count] = Table.Fse{ - .symbol = @intCast(u8, i), - .baseline = 0, - .bits = accuracy_log, - }; - less_than_one_count += 1; - } - } - - var position: usize = 0; - var temp_states: [1 << 9]u16 = undefined; - for (values) |value, symbol| { - if (value == 0 or value == 1) continue; - const probability = value - 1; - - const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch - return error.MalformedFseTable; - const share_size = @divExact(total_probability, state_share_dividend); - const double_state_count = state_share_dividend - probability; - const single_state_count = probability - double_state_count; - const share_size_log = std.math.log2_int(u16, share_size); - - var i: u16 = 0; - while (i < probability) : (i += 1) { - temp_states[i] = @intCast(u16, position); - position += (entries.len >> 1) + (entries.len >> 3) + 3; - position &= entries.len - 1; - while (position >= entries.len - less_than_one_count) { - position += (entries.len >> 1) + (entries.len >> 3) + 3; - position &= entries.len - 1; - } - } - std.sort.sort(u16, temp_states[0..probability], {}, std.sort.asc(u16)); - i = 0; - while (i < probability) : (i += 1) { - entries[temp_states[i]] = if (i < double_state_count) Table.Fse{ - .symbol = @intCast(u8, symbol), - .bits = share_size_log + 1, - .baseline = single_state_count * share_size + i * 2 * share_size, - } else Table.Fse{ - .symbol = @intCast(u8, symbol), - .bits = share_size_log, - .baseline = (i - double_state_count) * share_size, - }; - } - } -} - -fn decodeFseTable( - bit_reader: anytype, - expected_symbol_count: usize, - max_accuracy_log: u4, - entries: []Table.Fse, -) !usize { - const accuracy_log_biased = try bit_reader.readBitsNoEof(u4, 4); - if (accuracy_log_biased > max_accuracy_log -| 5) return error.MalformedAccuracyLog; - const accuracy_log = accuracy_log_biased + 5; - - var values: [256]u16 = undefined; - var value_count: usize = 0; - - const total_probability = @as(u16, 1) << accuracy_log; - var accumulated_probability: u16 = 0; - - while (accumulated_probability < total_probability) { - // WARNING: The RFC in poorly worded, and would suggest std.math.log2_int_ceil is correct here, - // but power of two (remaining probabilities + 1) need max bits set to 1 more. - const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1; - const small = try bit_reader.readBitsNoEof(u16, max_bits - 1); - - const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1); - - const value = if (small < cutoff) - small - else value: { - const value_read = small + (try bit_reader.readBitsNoEof(u16, 1) << (max_bits - 1)); - break :value if (value_read < @as(u16, 1) << (max_bits - 1)) - value_read - else - value_read - cutoff; - }; - - accumulated_probability += if (value != 0) value - 1 else 1; - - values[value_count] = value; - value_count += 1; - - if (value == 1) { - while (true) { - const repeat_flag = try bit_reader.readBitsNoEof(u2, 2); - var i: usize = 0; - while (i < repeat_flag) : (i += 1) { - values[value_count] = 1; - value_count += 1; - } - if (repeat_flag < 3) break; - } - } - } - bit_reader.alignToByte(); - - if (value_count < 2) return error.MalformedFseTable; - if (accumulated_probability != total_probability) return error.MalformedFseTable; - if (value_count > expected_symbol_count) return error.MalformedFseTable; - - const table_size = total_probability; - - try buildFseTable(values[0..value_count], entries[0..table_size]); - return table_size; -} - -const ReversedByteReader = struct { - remaining_bytes: usize, - bytes: []const u8, - - const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn); - - fn init(bytes: []const u8) ReversedByteReader { - return .{ - .bytes = bytes, - .remaining_bytes = bytes.len, - }; - } - - fn reader(self: *ReversedByteReader) Reader { - return .{ .context = self }; - } - - fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { - if (ctx.remaining_bytes == 0) return 0; - const byte_index = ctx.remaining_bytes - 1; - buffer[0] = ctx.bytes[byte_index]; - // buffer[0] = @bitReverse(ctx.bytes[byte_index]); - ctx.remaining_bytes = byte_index; - return 1; - } -}; - -/// A bit reader for reading the reversed bit streams used to encode -/// FSE compressed data. -pub const ReverseBitReader = struct { - byte_reader: ReversedByteReader, - bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader), - - pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void { - self.byte_reader = ReversedByteReader.init(bytes); - self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); - while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} - } - - pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U { - return self.bit_reader.readBitsNoEof(U, num_bits); - } - - pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U { - return try self.bit_reader.readBits(U, num_bits, out_bits); - } - - pub fn alignToByte(self: *@This()) void { - self.bit_reader.alignToByte(); - } -}; - -fn BitReader(comptime Reader: type) type { - return struct { - underlying: std.io.BitReader(.Little, Reader), - - fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { - return self.underlying.readBitsNoEof(U, num_bits); - } - - fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { - return self.underlying.readBits(U, num_bits, out_bits); - } - - fn alignToByte(self: *@This()) void { - self.underlying.alignToByte(); - } - }; -} - -pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { - return .{ .underlying = std.io.bitReader(.Little, reader) }; -} - test { std.testing.refAllDecls(@This()); } - -test buildFseTable { - const literals_length_default_values = [36]u16{ - 5, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 2, 2, 2, 2, - 0, 0, 0, 0, - }; - - const match_lengths_default_values = [53]u16{ - 2, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, - 0, 0, 0, 0, 0, - }; - - const offset_codes_default_values = [29]u16{ - 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, - }; - - var entries: [64]Table.Fse = undefined; - try buildFseTable(&literals_length_default_values, &entries); - try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_literal_fse_table.fse, &entries); - - try buildFseTable(&match_lengths_default_values, &entries); - try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_match_fse_table.fse, &entries); - - try buildFseTable(&offset_codes_default_values, entries[0..32]); - try std.testing.expectEqualSlices(Table.Fse, types.compressed_block.predefined_offset_fse_table.fse, entries[0..32]); -} diff --git a/lib/std/compress/zstandard/readers.zig b/lib/std/compress/zstandard/readers.zig new file mode 100644 index 0000000000..489f933310 --- /dev/null +++ b/lib/std/compress/zstandard/readers.zig @@ -0,0 +1,75 @@ +const std = @import("std"); + +pub const ReversedByteReader = struct { + remaining_bytes: usize, + bytes: []const u8, + + const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn); + + pub fn init(bytes: []const u8) ReversedByteReader { + return .{ + .bytes = bytes, + .remaining_bytes = bytes.len, + }; + } + + pub fn reader(self: *ReversedByteReader) Reader { + return .{ .context = self }; + } + + fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { + if (ctx.remaining_bytes == 0) return 0; + const byte_index = ctx.remaining_bytes - 1; + buffer[0] = ctx.bytes[byte_index]; + // buffer[0] = @bitReverse(ctx.bytes[byte_index]); + ctx.remaining_bytes = byte_index; + return 1; + } +}; + +/// A bit reader for reading the reversed bit streams used to encode +/// FSE compressed data. +pub const ReverseBitReader = struct { + byte_reader: ReversedByteReader, + bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader), + + pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void { + self.byte_reader = ReversedByteReader.init(bytes); + self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); + while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} + } + + pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U { + return self.bit_reader.readBitsNoEof(U, num_bits); + } + + pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U { + return try self.bit_reader.readBits(U, num_bits, out_bits); + } + + pub fn alignToByte(self: *@This()) void { + self.bit_reader.alignToByte(); + } +}; + +pub fn BitReader(comptime Reader: type) type { + return struct { + underlying: std.io.BitReader(.Little, Reader), + + pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { + return self.underlying.readBitsNoEof(U, num_bits); + } + + pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { + return self.underlying.readBits(U, num_bits, out_bits); + } + + pub fn alignToByte(self: *@This()) void { + self.underlying.alignToByte(); + } + }; +} + +pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { + return .{ .underlying = std.io.bitReader(.Little, reader) }; +}