From 7558bf64513ec2be59b95aecc5e0ac50ad88b1f5 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Wed, 25 Jan 2023 01:30:17 +1100 Subject: [PATCH] std.compress.zstandard: minor cleanup and add doc comments --- lib/std/compress/zstandard/decompress.zig | 72 +++++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 9119a51c7d..ff554ee6c8 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -18,6 +18,10 @@ fn isSkippableMagic(magic: u32) bool { return frame.Skippable.magic_number_min <= magic and magic <= frame.Skippable.magic_number_max; } +/// Returns the decompressed size of the frame at the start of `src`. Returns 0 +/// if the the frame is skippable, `null` for Zstanndard frames that do not +/// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet` +/// errors if the respective bits of the the frame descriptor are set. pub fn getFrameDecompressedSize(src: []const u8) !?usize { switch (try frameType(src)) { .zstandard => { @@ -28,7 +32,10 @@ pub fn getFrameDecompressedSize(src: []const u8) !?usize { } } -pub fn frameType(src: []const u8) !frame.Kind { +/// Returns the kind of frame at the beginning of `src`. Returns `BadMagic` if +/// `src` begin with bytes not equal to the Zstandard frame magic number, or +/// outside the range of magic numbers for skippable frames. +pub fn frameType(src: []const u8) error{BadMagic}!frame.Kind { const magic = readInt(u32, src[0..4]); return if (magic == frame.ZStandard.magic_number) .zstandard @@ -43,11 +50,13 @@ const ReadWriteCount = struct { write_count: usize, }; +/// Decodes the frame at the start of `src` into `dest`. Returns the number of +/// bytes read from `src` and written to `dest`. pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount { return switch (try frameType(src)) { .zstandard => decodeZStandardFrame(dest, src, verify_checksum), .skippable => ReadWriteCount{ - .read_count = try skippableFrameSize(src[0..8]) + 8, + .read_count = skippableFrameSize(src[0..8]) + 8, .write_count = 0, }, }; @@ -82,6 +91,10 @@ pub const DecodeState = struct { }; } + /// Prepare the decoder to decode a compressed block. Loads the literals + /// stream and Huffman tree from `literals` and reads the FSE tables from `src`. + /// Returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's + /// first byte does not have any bits set. pub fn prepare( self: *DecodeState, src: []const u8, @@ -130,6 +143,8 @@ pub const DecodeState = struct { return 0; } + /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream` + /// if `bit_reader` does not contain enough bits. pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void { self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log); self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log); @@ -283,6 +298,14 @@ pub const DecodeState = struct { for (copy_slice.second) |b| dest.writeAssumeCapacity(b); } + /// Decode one sequence from `bit_reader` into `dest`, written starting at + /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns + /// `error.MalformedSequence` error if the decompressed sequence would be longer + /// than `sequence_size_limit` or the sequence's offset is too large; returns + /// `error.EndOfStream` if `bit_reader` does not contain enough bits; returns + /// `error.UnexpectedEndOfLiteralStream` if the decoder state's literal streams + /// do not contain enough literals for the sequence (this may mean the literal + /// stream or the sequence is malformed). pub fn decodeSequenceSlice( self: *DecodeState, dest: []u8, @@ -305,6 +328,7 @@ pub const DecodeState = struct { return sequence_length; } + /// Decode one sequence from `bit_reader` into `dest`; see `decodeSequenceSlice`. pub fn decodeSequenceRingBuffer( self: *DecodeState, dest: *RingBuffer, @@ -335,6 +359,12 @@ pub const DecodeState = struct { try self.literal_stream_reader.init(bytes); } + /// Decode `len` bytes of literals into `dest`. `literals` should be the + /// `LiteralsSection` that was passed to `prepare()`. Returns + /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by + /// `self` plus `len` is greater than the regenerated size of `literals`. + /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if + /// there are problems decoding Huffman compressed literals. pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void { if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength; switch (literals.header.block_type) { @@ -403,6 +433,7 @@ pub const DecodeState = struct { } } + /// Decode literals into `dest`; see `decodeLiteralsSlice()`. pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void { if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength; switch (literals.header.block_type) { @@ -483,6 +514,13 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match; const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match; +/// Decode a Zstandard frame from `src` into `dest`, returning the number of +/// bytes read from `src` and written to `dest`; if the frame does not declare +/// its decompressed content size `error.UnknownContentSizeUnsupported` is +/// returned. Returns `error.DictionaryIdFlagUnsupported` if the frame uses a +/// dictionary, and `error.ChecksumFailure` if `verify_checksum` is `true` and +/// the frame contains a checksum that does not match the checksum computed from +/// the decompressed frame. pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount { assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number); var consumed_count: usize = 4; @@ -520,6 +558,10 @@ pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool) return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count }; } +/// Decode a Zstandard from from `src` and return the decompressed bytes; see +/// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame +/// does not declare its content size or a window descriptor (this indicates a +/// malformed frame). pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8, verify_checksum: bool) ![]u8 { var result = std.ArrayList(u8).init(allocator); assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number); @@ -599,6 +641,7 @@ pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8, return result.toOwnedSlice(); } +/// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`. pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64) !usize { // These tables take 7680 bytes var literal_fse_data: [literal_table_size_max]Table.Fse = undefined; @@ -686,6 +729,10 @@ fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, return block_size; } +/// Decode a single block from `src` into `dest`. The beginning of `src` should +/// be the start of the block content (i.e. directly after the block header). +/// Increments `consumed_count` by the number of bytes read from `src` to decode +/// the block and returns the decompressed size of the block. pub fn decodeBlock( dest: []u8, src: []const u8, @@ -750,6 +797,9 @@ pub fn decodeBlock( } } +/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns +/// the size of the decompressed block, which can be used with `dest.sliceLast()` +/// to get the decompressed bytes. pub fn decodeBlockRingBuffer( dest: *RingBuffer, src: []const u8, @@ -811,6 +861,7 @@ pub fn decodeBlockRingBuffer( } } +/// Decode the header of a skippable frame. pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header { const magic = readInt(u32, src[0..4]); assert(isSkippableMagic(magic)); @@ -821,12 +872,15 @@ pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header { }; } -pub fn skippableFrameSize(src: *const [8]u8) !usize { +/// Returns the content size of a skippable frame. +pub fn skippableFrameSize(src: *const [8]u8) usize { assert(isSkippableMagic(readInt(u32, src[0..4]))); const frame_size = readInt(u32, src[4..8]); return frame_size; } +/// Returns the window size required to decompress a frame, or `null` if it cannot be +/// determined, which indicates a malformed frame header. pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 { if (header.window_descriptor) |descriptor| { const exponent = (descriptor & 0b11111000) >> 3; @@ -838,6 +892,8 @@ pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 { } else return header.content_size; } +/// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or +/// `error.ReservedBitSet` if the corresponding bits are sets. pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZStandard.Header { const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]); @@ -879,6 +935,7 @@ pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZS return header; } +/// Decode the header of a block. pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { const last_block = src[0] & 1 == 1; const block_type = @intToEnum(frame.ZStandard.Block.Type, (src[0] & 0b110) >> 1); @@ -890,6 +947,8 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { }; } +/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the +/// number of bytes the section uses. pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection { var bytes_read: usize = 0; const header = try decodeLiteralsHeader(src, &bytes_read); @@ -1107,6 +1166,7 @@ fn lessThanByWeight( return weights[lhs.symbol] < weights[rhs.symbol]; } +/// Decode a literals section header. pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSection.Header { if (src.len == 0) return error.MalformedLiteralsSection; const byte0 = src[0]; @@ -1172,6 +1232,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSe }; } +/// Decode a sequences section header. pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header { if (src.len == 0) return error.MalformedSequencesSection; var sequence_count: u24 = undefined; @@ -1241,7 +1302,8 @@ fn buildFseTable(values: []const u16, entries: []Table.Fse) !void { if (value == 0 or value == 1) continue; const probability = value - 1; - const state_share_dividend = try std.math.ceilPowerOfTwo(u16, probability); + const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch + return error.MalformedFseTable; const share_size = @divExact(total_probability, state_share_dividend); const double_state_count = state_share_dividend - probability; const single_state_count = probability - double_state_count; @@ -1363,6 +1425,8 @@ const ReversedByteReader = struct { } }; +/// A bit reader for reading the reversed bit streams used to encode +/// FSE compressed data. pub const ReverseBitReader = struct { byte_reader: ReversedByteReader, bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),