From 5723291444116419440a187adcfa5ecb9557544e Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Thu, 2 Feb 2023 16:19:13 +1100 Subject: [PATCH] std.compress.zstandard: add `decodeBlockReader` --- lib/std/compress/zstandard/decompress.zig | 822 ++++++++++++---------- 1 file changed, 462 insertions(+), 360 deletions(-) diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 0caf31fa33..e32f4d0282 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -14,29 +14,18 @@ fn readVarInt(comptime T: type, bytes: []const u8) T { return std.mem.readVarInt(T, bytes, .Little); } -fn isSkippableMagic(magic: u32) bool { +pub fn isSkippableMagic(magic: u32) bool { return frame.Skippable.magic_number_min <= magic and magic <= frame.Skippable.magic_number_max; } -/// Returns the decompressed size of the frame at the start of `src`. Returns 0 -/// if the the frame is skippable, `null` for Zstanndard frames that do not -/// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet` -/// errors if the respective bits of the the frame descriptor are set. -pub fn getFrameDecompressedSize(src: []const u8) (InvalidBit || error{BadMagic})!?u64 { - switch (try frameType(src)) { - .zstandard => { - const header = try decodeZStandardHeader(src[4..], null); - return header.content_size; - }, - .skippable => return 0, - } -} - -/// Returns the kind of frame at the beginning of `src`. Returns `BadMagic` if -/// `src` begin with bytes not equal to the Zstandard frame magic number, or -/// outside the range of magic numbers for skippable frames. -pub fn frameType(src: []const u8) error{BadMagic}!frame.Kind { - const magic = readInt(u32, src[0..4]); +/// Returns the kind of frame at the beginning of `src`. 
+/// +/// Errors: +/// - returns `error.BadMagic` if `source` begins with bytes not equal to the +/// Zstandard frame magic number, or outside the range of magic numbers for +/// skippable frames. +pub fn decodeFrameType(source: anytype) !frame.Kind { + const magic = try source.readIntLittle(u32); return if (magic == frame.ZStandard.magic_number) .zstandard else if (isSkippableMagic(magic)) @@ -52,15 +41,21 @@ const ReadWriteCount = struct { /// Decodes the frame at the start of `src` into `dest`. Returns the number of /// bytes read from `src` and written to `dest`. +/// +/// Errors: +/// - returns `error.UnknownContentSizeUnsupported` +/// - returns `error.ContentTooLarge` +/// - returns `error.BadMagic` pub fn decodeFrame( dest: []u8, src: []const u8, verify_checksum: bool, -) (error{ UnknownContentSizeUnsupported, ContentTooLarge, BadMagic } || FrameError)!ReadWriteCount { - return switch (try frameType(src)) { +) !ReadWriteCount { + var fbs = std.io.fixedBufferStream(src); + return switch (try decodeFrameType(fbs.reader())) { .zstandard => decodeZStandardFrame(dest, src, verify_checksum), .skippable => ReadWriteCount{ - .read_count = skippableFrameSize(src[0..8]) + 8, + .read_count = try fbs.reader().readIntLittle(u32) + 8, .write_count = 0, }, }; @@ -97,16 +92,52 @@ pub const DecodeState = struct { }; } + pub fn init( + literal_fse_buffer: []Table.Fse, + match_fse_buffer: []Table.Fse, + offset_fse_buffer: []Table.Fse, + ) DecodeState { + return DecodeState{ + .repeat_offsets = .{ + types.compressed_block.start_repeated_offset_1, + types.compressed_block.start_repeated_offset_2, + types.compressed_block.start_repeated_offset_3, + }, + + .offset = undefined, + .match = undefined, + .literal = undefined, + + .literal_fse_buffer = literal_fse_buffer, + .match_fse_buffer = match_fse_buffer, + .offset_fse_buffer = offset_fse_buffer, + + .fse_tables_undefined = true, + + .literal_written_count = 0, + .literal_header = undefined, + .literal_streams = undefined, + 
.literal_stream_reader = undefined, + .literal_stream_index = undefined, + .huffman_tree = null, + }; + } + /// Prepare the decoder to decode a compressed block. Loads the literals - /// stream and Huffman tree from `literals` and reads the FSE tables from `src`. - /// Returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's - /// first byte does not have any bits set. + /// stream and Huffman tree from `literals` and reads the FSE tables from + /// `source`. + /// + /// Errors: + /// - returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's + /// first byte does not have any bits set. + /// - returns `error.TreelessLiteralsFirst` `literals` is a treeless literals section + /// and the decode state does not have a Huffman tree from a previous block. pub fn prepare( self: *DecodeState, - src: []const u8, + source: anytype, literals: LiteralsSection, sequences_header: SequencesSection.Header, - ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize { + ) !void { self.literal_written_count = 0; self.literal_header = literals.header; self.literal_streams = literals.streams; @@ -129,28 +160,11 @@ pub const DecodeState = struct { } if (sequences_header.sequence_count > 0) { - var bytes_read = try self.updateFseTable( - src, - .literal, - sequences_header.literal_lengths, - ); - - bytes_read += try self.updateFseTable( - src[bytes_read..], - .offset, - sequences_header.offsets, - ); - - bytes_read += try self.updateFseTable( - src[bytes_read..], - .match, - sequences_header.match_lengths, - ); + try self.updateFseTable(source, .literal, sequences_header.literal_lengths); + try self.updateFseTable(source, .offset, sequences_header.offsets); + try self.updateFseTable(source, .match, sequences_header.match_lengths); self.fse_tables_undefined = false; - - return bytes_read; } - return 0; } /// Read initial FSE states for sequence decoding. 
Returns `error.EndOfStream` @@ -208,10 +222,10 @@ pub const DecodeState = struct { fn updateFseTable( self: *DecodeState, - src: []const u8, + source: anytype, comptime choice: DataType, mode: SequencesSection.Header.Mode, - ) FseTableError!usize { + ) !void { const field_name = @tagName(choice); switch (mode) { .predefined => { @@ -220,17 +234,13 @@ pub const DecodeState = struct { @field(self, field_name).table = @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table"); - return 0; }, .rle => { @field(self, field_name).accuracy_log = 0; - @field(self, field_name).table = .{ .rle = src[0] }; - return 1; + @field(self, field_name).table = .{ .rle = try source.readByte() }; }, .fse => { - var stream = std.io.fixedBufferStream(src); - var counting_reader = std.io.countingReader(stream.reader()); - var bit_reader = bitReader(counting_reader.reader()); + var bit_reader = bitReader(source); const table_size = try decodeFseTable( &bit_reader, @@ -242,9 +252,8 @@ pub const DecodeState = struct { .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size], }; @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size); - return std.math.cast(usize, counting_reader.bytes_read) orelse error.MalformedFseTable; }, - .repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0, + .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst, } } @@ -462,11 +471,15 @@ pub const DecodeState = struct { while (i < len) : (i += 1) { var prefix: u16 = 0; while (true) { - const new_bits = try self.readLiteralsBits(u16, bit_count_to_read); + const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| { + return err; + }; prefix <<= bit_count_to_read; prefix |= new_bits; bits_read += bit_count_to_read; - const result = try huffman_tree.query(huffman_tree_index, prefix); + const result = huffman_tree.query(huffman_tree_index, prefix) catch |err| { + return err; + }; switch (result) { .symbol => |sym| { 
@@ -589,11 +602,14 @@ pub fn decodeZStandardFrame( dest: []u8, src: []const u8, verify_checksum: bool, -) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount { +) (error{ UnknownContentSizeUnsupported, ContentTooLarge, EndOfStream } || FrameError)!ReadWriteCount { assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number); var consumed_count: usize = 4; - const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count); + var fbs = std.io.fixedBufferStream(src[consumed_count..]); + var source = fbs.reader(); + const frame_header = try decodeZStandardHeader(source); + consumed_count += fbs.pos; if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported; @@ -649,18 +665,25 @@ pub const FrameContext = struct { /// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame /// does not declare its content size or a window descriptor (this indicates a /// malformed frame). +/// +/// Errors: +/// - returns `error.WindowTooLarge` +/// - returns `error.WindowSizeUnknown` pub fn decodeZStandardFrameAlloc( allocator: std.mem.Allocator, src: []const u8, verify_checksum: bool, window_size_max: usize, -) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory } || FrameError)![]u8 { +) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory, EndOfStream } || FrameError)![]u8 { var result = std.ArrayList(u8).init(allocator); assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number); var consumed_count: usize = 4; var frame_context = context: { - const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count); + var fbs = std.io.fixedBufferStream(src[consumed_count..]); + var source = fbs.reader(); + const frame_header = try decodeZStandardHeader(source); + consumed_count += fbs.pos; break :context try FrameContext.init(frame_header, window_size_max, verify_checksum); }; @@ -674,30 +697,7 @@ pub fn decodeZStandardFrameAlloc( var block_header = 
decodeBlockHeader(src[consumed_count..][0..3]); consumed_count += 3; - var decode_state = DecodeState{ - .repeat_offsets = .{ - types.compressed_block.start_repeated_offset_1, - types.compressed_block.start_repeated_offset_2, - types.compressed_block.start_repeated_offset_3, - }, - - .offset = undefined, - .match = undefined, - .literal = undefined, - - .literal_fse_buffer = &literal_fse_data, - .match_fse_buffer = &match_fse_data, - .offset_fse_buffer = &offset_fse_data, - - .fse_tables_undefined = true, - - .literal_written_count = 0, - .literal_header = undefined, - .literal_streams = undefined, - .literal_stream_reader = undefined, - .literal_stream_index = undefined, - .huffman_tree = null, - }; + var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); while (true) : ({ block_header = decodeBlockHeader(src[consumed_count..][0..3]); consumed_count += 3; @@ -754,30 +754,7 @@ pub fn decodeFrameBlocks( var block_header = decodeBlockHeader(src[0..3]); var bytes_read: usize = 3; defer consumed_count.* += bytes_read; - var decode_state = DecodeState{ - .repeat_offsets = .{ - types.compressed_block.start_repeated_offset_1, - types.compressed_block.start_repeated_offset_2, - types.compressed_block.start_repeated_offset_3, - }, - - .offset = undefined, - .match = undefined, - .literal = undefined, - - .literal_fse_buffer = &literal_fse_data, - .match_fse_buffer = &match_fse_data, - .offset_fse_buffer = &offset_fse_data, - - .fse_tables_undefined = true, - - .literal_written_count = 0, - .literal_header = undefined, - .literal_streams = undefined, - .literal_stream_reader = undefined, - .literal_stream_index = undefined, - .huffman_tree = null, - }; + var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data); var written_count: usize = 0; while (true) : ({ block_header = decodeBlockHeader(src[bytes_read..][0..3]); @@ -798,62 +775,6 @@ pub fn decodeFrameBlocks( return written_count; } -fn decodeRawBlock( - 
dest: []u8, - src: []const u8, - block_size: u21, - consumed_count: *usize, -) error{MalformedBlockSize}!usize { - if (src.len < block_size) return error.MalformedBlockSize; - const data = src[0..block_size]; - std.mem.copy(u8, dest, data); - consumed_count.* += block_size; - return block_size; -} - -fn decodeRawBlockRingBuffer( - dest: *RingBuffer, - src: []const u8, - block_size: u21, - consumed_count: *usize, -) error{MalformedBlockSize}!usize { - if (src.len < block_size) return error.MalformedBlockSize; - const data = src[0..block_size]; - dest.writeSliceAssumeCapacity(data); - consumed_count.* += block_size; - return block_size; -} - -fn decodeRleBlock( - dest: []u8, - src: []const u8, - block_size: u21, - consumed_count: *usize, -) error{MalformedRleBlock}!usize { - if (src.len < 1) return error.MalformedRleBlock; - var write_pos: usize = 0; - while (write_pos < block_size) : (write_pos += 1) { - dest[write_pos] = src[0]; - } - consumed_count.* += 1; - return block_size; -} - -fn decodeRleBlockRingBuffer( - dest: *RingBuffer, - src: []const u8, - block_size: u21, - consumed_count: *usize, -) error{MalformedRleBlock}!usize { - if (src.len < 1) return error.MalformedRleBlock; - var write_pos: usize = 0; - while (write_pos < block_size) : (write_pos += 1) { - dest.writeAssumeCapacity(src[0]); - } - consumed_count.* += 1; - return block_size; -} - /// Decode a single block from `src` into `dest`. The beginning of `src` should /// be the start of the block content (i.e. directly after the block header). 
/// Increments `consumed_count` by the number of bytes read from `src` to decode @@ -870,19 +791,37 @@ pub fn decodeBlock( const block_size = block_header.block_size; if (block_size_max < block_size) return error.BlockSizeOverMaximum; switch (block_header.block_type) { - .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count), - .rle => return decodeRleBlock(dest[written_count..], src, block_size, consumed_count), + .raw => { + if (src.len < block_size) return error.MalformedBlockSize; + const data = src[0..block_size]; + std.mem.copy(u8, dest[written_count..], data); + consumed_count.* += block_size; + return block_size; + }, + .rle => { + if (src.len < 1) return error.MalformedRleBlock; + var write_pos: usize = written_count; + while (write_pos < block_size + written_count) : (write_pos += 1) { + dest[write_pos] = src[0]; + } + consumed_count.* += 1; + return block_size; + }, .compressed => { if (src.len < block_size) return error.MalformedBlockSize; var bytes_read: usize = 0; - const literals = decodeLiteralsSection(src, &bytes_read) catch + const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch return error.MalformedCompressedBlock; - const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch + var fbs = std.io.fixedBufferStream(src[bytes_read..]); + const fbs_reader = fbs.reader(); + const sequences_header = decodeSequencesHeader(fbs_reader) catch return error.MalformedCompressedBlock; - bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch + decode_state.prepare(fbs_reader, literals, sequences_header) catch return error.MalformedCompressedBlock; + bytes_read += fbs.pos; + var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { const bit_stream_bytes = src[bytes_read..block_size]; @@ -938,19 +877,37 @@ pub fn decodeBlockRingBuffer( const block_size = block_header.block_size; if (block_size_max < block_size) return error.BlockSizeOverMaximum; 
switch (block_header.block_type) { - .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count), - .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count), + .raw => { + if (src.len < block_size) return error.MalformedBlockSize; + const data = src[0..block_size]; + dest.writeSliceAssumeCapacity(data); + consumed_count.* += block_size; + return block_size; + }, + .rle => { + if (src.len < 1) return error.MalformedRleBlock; + var write_pos: usize = 0; + while (write_pos < block_size) : (write_pos += 1) { + dest.writeAssumeCapacity(src[0]); + } + consumed_count.* += 1; + return block_size; + }, .compressed => { if (src.len < block_size) return error.MalformedBlockSize; var bytes_read: usize = 0; - const literals = decodeLiteralsSection(src, &bytes_read) catch + const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch return error.MalformedCompressedBlock; - const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch + var fbs = std.io.fixedBufferStream(src[bytes_read..]); + const fbs_reader = fbs.reader(); + const sequences_header = decodeSequencesHeader(fbs_reader) catch return error.MalformedCompressedBlock; - bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch + decode_state.prepare(fbs_reader, literals, sequences_header) catch return error.MalformedCompressedBlock; + bytes_read += fbs.pos; + var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { const bit_stream_bytes = src[bytes_read..block_size]; @@ -991,6 +948,82 @@ pub fn decodeBlockRingBuffer( } } +/// Decode a single block from `source` into `dest`. Literal and sequence data +/// from the block is copied into `literals_buffer` and `sequence_buffer`, which +/// must be large enough or `error.LiteralsBufferTooSmall` and +/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an +/// upper bound for the size of both buffers). 
See `decodeBlock` +/// and `decodeBlockRingBuffer` for function that can decode a block without +/// these extra copies. +pub fn decodeBlockReader( + dest: *RingBuffer, + source: anytype, + block_header: frame.ZStandard.Block.Header, + decode_state: *DecodeState, + block_size_max: usize, + literals_buffer: []u8, + sequence_buffer: []u8, +) !void { + const block_size = block_header.block_size; + var block_reader_limited = std.io.limitedReader(source, block_size); + const block_reader = block_reader_limited.reader(); + if (block_size_max < block_size) return error.BlockSizeOverMaximum; + switch (block_header.block_type) { + .raw => { + const slice = dest.sliceAt(dest.write_index, block_size); + try source.readNoEof(slice.first); + try source.readNoEof(slice.second); + dest.write_index = dest.mask2(dest.write_index + block_size); + }, + .rle => { + const byte = try source.readByte(); + var i: usize = 0; + while (i < block_size) : (i += 1) { + dest.writeAssumeCapacity(byte); + } + }, + .compressed => { + const literals = try decodeLiteralsSection(block_reader, literals_buffer); + const sequences_header = try decodeSequencesHeader(block_reader); + + try decode_state.prepare(block_reader, literals, sequences_header); + + if (sequences_header.sequence_count > 0) { + if (sequence_buffer.len < block_reader_limited.bytes_left) + return error.SequenceBufferTooSmall; + + const size = try block_reader.readAll(sequence_buffer); + var bit_stream: ReverseBitReader = undefined; + try bit_stream.init(sequence_buffer[0..size]); + + decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock; + + var sequence_size_limit = block_size_max; + var i: usize = 0; + while (i < sequences_header.sequence_count) : (i += 1) { + const decompressed_size = decode_state.decodeSequenceRingBuffer( + dest, + &bit_stream, + sequence_size_limit, + i == sequences_header.sequence_count - 1, + ) catch return error.MalformedCompressedBlock; + sequence_size_limit -= 
decompressed_size; + } + } + + if (decode_state.literal_written_count < literals.header.regenerated_size) { + const len = literals.header.regenerated_size - decode_state.literal_written_count; + decode_state.decodeLiteralsRingBuffer(dest, len) catch + return error.MalformedCompressedBlock; + } + + decode_state.literal_written_count = 0; + assert(block_reader.readByte() == error.EndOfStream); + }, + .reserved => return error.ReservedBlock, + } +} + /// Decode the header of a skippable frame. pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header { const magic = readInt(u32, src[0..4]); @@ -1002,13 +1035,6 @@ pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header { }; } -/// Returns the content size of a skippable frame. -pub fn skippableFrameSize(src: *const [8]u8) usize { - assert(isSkippableMagic(readInt(u32, src[0..4]))); - const frame_size = readInt(u32, src[4..8]); - return frame_size; -} - /// Returns the window size required to decompress a frame, or `null` if it cannot be /// determined, which indicates a malformed frame header. pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 { @@ -1023,40 +1049,37 @@ pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 { } const InvalidBit = error{ UnusedBitSet, ReservedBitSet }; -/// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or -/// `error.ReservedBitSet` if the corresponding bits are sets. -pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) InvalidBit!frame.ZStandard.Header { - const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]); +/// Decode the header of a Zstandard frame. 
+/// +/// Errors: +/// - returns `error.UnusedBitSet` if the unused bits of the header are set +/// - returns `error.ReservedBitSet` if the reserved bits of the header are +/// set +pub fn decodeZStandardHeader(source: anytype) (error{EndOfStream} || InvalidBit)!frame.ZStandard.Header { + const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, try source.readByte()); if (descriptor.unused) return error.UnusedBitSet; if (descriptor.reserved) return error.ReservedBitSet; - var bytes_read_count: usize = 1; - var window_descriptor: ?u8 = null; if (!descriptor.single_segment_flag) { - window_descriptor = src[bytes_read_count]; - bytes_read_count += 1; + window_descriptor = try source.readByte(); } var dictionary_id: ?u32 = null; if (descriptor.dictionary_id_flag > 0) { // if flag is 3 then field_size = 4, else field_size = flag const field_size = (@as(u4, 1) << descriptor.dictionary_id_flag) >> 1; - dictionary_id = readVarInt(u32, src[bytes_read_count .. bytes_read_count + field_size]); - bytes_read_count += field_size; + dictionary_id = try source.readVarInt(u32, .Little, field_size); } var content_size: ?u64 = null; if (descriptor.single_segment_flag or descriptor.content_size_flag > 0) { const field_size = @as(u4, 1) << descriptor.content_size_flag; - content_size = readVarInt(u64, src[bytes_read_count .. bytes_read_count + field_size]); + content_size = try source.readVarInt(u64, .Little, field_size); if (field_size == 2) content_size.? += 256; - bytes_read_count += field_size; } - if (consumed_count) |p| p.* += bytes_read_count; - const header = frame.ZStandard.Header{ .descriptor = descriptor, .window_descriptor = window_descriptor, @@ -1080,12 +1103,20 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { /// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the /// number of bytes the section uses. 
-pub fn decodeLiteralsSection( +/// +/// Errors: +/// - returns `error.MalformedLiteralsHeader` if the header is invalid +/// - returns `error.MalformedLiteralsSection` if there are errors decoding +pub fn decodeLiteralsSectionSlice( src: []const u8, consumed_count: *usize, -) (error{ MalformedLiteralsHeader, MalformedLiteralsSection } || DecodeHuffmanError)!LiteralsSection { +) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || DecodeHuffmanError)!LiteralsSection { var bytes_read: usize = 0; - const header = try decodeLiteralsHeader(src, &bytes_read); + const header = header: { + var fbs = std.io.fixedBufferStream(src); + defer bytes_read = fbs.pos; + break :header decodeLiteralsHeader(fbs.reader()) catch return error.MalformedLiteralsHeader; + }; switch (header.block_type) { .raw => { if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection; @@ -1110,7 +1141,7 @@ pub fn decodeLiteralsSection( .compressed, .treeless => { const huffman_tree_start = bytes_read; const huffman_tree = if (header.block_type == .compressed) - try decodeHuffmanTree(src[bytes_read..], &bytes_read) + try decodeHuffmanTreeSlice(src[bytes_read..], &bytes_read) else null; const huffman_tree_size = bytes_read - huffman_tree_start; @@ -1119,137 +1150,185 @@ pub fn decodeLiteralsSection( if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection; const stream_data = src[bytes_read .. 
bytes_read + total_streams_size]; - if (header.size_format == 0) { - consumed_count.* += total_streams_size + bytes_read; - return LiteralsSection{ - .header = header, - .huffman_tree = huffman_tree, - .streams = .{ .one = stream_data }, - }; - } - - if (stream_data.len < 6) return error.MalformedLiteralsSection; - - const stream_1_length = @as(usize, readInt(u16, stream_data[0..2])); - const stream_2_length = @as(usize, readInt(u16, stream_data[2..4])); - const stream_3_length = @as(usize, readInt(u16, stream_data[4..6])); - const stream_4_length = (total_streams_size - 6) - (stream_1_length + stream_2_length + stream_3_length); - - const stream_1_start = 6; - const stream_2_start = stream_1_start + stream_1_length; - const stream_3_start = stream_2_start + stream_2_length; - const stream_4_start = stream_3_start + stream_3_length; - - if (stream_data.len < stream_4_start + stream_4_length) return error.MalformedLiteralsSection; - consumed_count.* += total_streams_size + bytes_read; - + const streams = try decodeStreams(header.size_format, stream_data); + consumed_count.* += bytes_read + total_streams_size; return LiteralsSection{ .header = header, .huffman_tree = huffman_tree, - .streams = .{ .four = .{ - stream_data[stream_1_start .. stream_1_start + stream_1_length], - stream_data[stream_2_start .. stream_2_start + stream_2_length], - stream_data[stream_3_start .. stream_3_start + stream_3_length], - stream_data[stream_4_start .. stream_4_start + stream_4_length], - } }, + .streams = streams, }; }, } } +/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the +/// number of bytes the section uses. 
+/// +/// Errors: +/// - returns `error.MalformedLiteralsHeader` if the header is invalid +/// - returns `error.MalformedLiteralsSection` if there are errors decoding +pub fn decodeLiteralsSection( + source: anytype, + buffer: []u8, +) !LiteralsSection { + const header = try decodeLiteralsHeader(source); + switch (header.block_type) { + .raw => { + try source.readNoEof(buffer[0..header.regenerated_size]); + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = buffer }, + }; + }, + .rle => { + buffer[0] = try source.readByte(); + return LiteralsSection{ + .header = header, + .huffman_tree = null, + .streams = .{ .one = buffer[0..1] }, + }; + }, + .compressed, .treeless => { + var counting_reader = std.io.countingReader(source); + const huffman_tree = if (header.block_type == .compressed) + try decodeHuffmanTree(counting_reader.reader(), buffer) + else + null; + const huffman_tree_size = counting_reader.bytes_read; + const total_streams_size = @as(usize, header.compressed_size.?) 
- @intCast(usize, huffman_tree_size); + + if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall; + try source.readNoEof(buffer[0..total_streams_size]); + const stream_data = buffer[0..total_streams_size]; + + const streams = try decodeStreams(header.size_format, stream_data); + return LiteralsSection{ + .header = header, + .huffman_tree = huffman_tree, + .streams = streams, + }; + }, + } +} + +fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams { + if (size_format == 0) { + return .{ .one = stream_data }; + } + + if (stream_data.len < 6) return error.MalformedLiteralsSection; + + const stream_1_length = @as(usize, readInt(u16, stream_data[0..2])); + const stream_2_length = @as(usize, readInt(u16, stream_data[2..4])); + const stream_3_length = @as(usize, readInt(u16, stream_data[4..6])); + + const stream_1_start = 6; + const stream_2_start = stream_1_start + stream_1_length; + const stream_3_start = stream_2_start + stream_2_length; + const stream_4_start = stream_3_start + stream_3_length; + + return .{ .four = .{ + stream_data[stream_1_start .. stream_1_start + stream_1_length], + stream_data[stream_2_start .. stream_2_start + stream_2_length], + stream_data[stream_3_start .. stream_3_start + stream_3_length], + stream_data[stream_4_start..], + } }; +} + const DecodeHuffmanError = error{ MalformedHuffmanTree, MalformedFseTable, MalformedAccuracyLog, }; -fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError!LiteralsSection.HuffmanTree { - var bytes_read: usize = 0; - bytes_read += 1; - if (src.len == 0) return error.MalformedHuffmanTree; - const header = src[0]; - var symbol_count: usize = undefined; - var weights: [256]u4 = undefined; - var max_number_of_bits: u4 = undefined; - if (header < 128) { - // FSE compressed weights - const compressed_size = header; - if (src.len < 1 + compressed_size) return error.MalformedHuffmanTree; - var stream = std.io.fixedBufferStream(src[1 .. 
compressed_size + 1]); - var counting_reader = std.io.countingReader(stream.reader()); - var bit_reader = bitReader(counting_reader.reader()); +fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize { + var stream = std.io.limitedReader(source, compressed_size); + var bit_reader = bitReader(stream.reader()); - var entries: [1 << 6]Table.Fse = undefined; - const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { - error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, - error.EndOfStream => return error.MalformedFseTable, - }; - const accuracy_log = std.math.log2_int_ceil(usize, table_size); + var entries: [1 << 6]Table.Fse = undefined; + const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { + error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, + error.EndOfStream => return error.MalformedFseTable, + }; + const accuracy_log = std.math.log2_int_ceil(usize, table_size); - const start_index = std.math.cast(usize, 1 + counting_reader.bytes_read) orelse return error.MalformedHuffmanTree; - var huff_data = src[start_index .. 
compressed_size + 1]; - var huff_bits: ReverseBitReader = undefined; - huff_bits.init(huff_data) catch return error.MalformedHuffmanTree; + const amount = try stream.reader().readAll(buffer); + var huff_bits: ReverseBitReader = undefined; + huff_bits.init(buffer[0..amount]) catch return error.MalformedHuffmanTree; - var i: usize = 0; - var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; - var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; + return assignWeights(&huff_bits, accuracy_log, &entries, weights); +} - while (i < 255) { - const even_data = entries[even_state]; - var read_bits: usize = 0; - const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable; - weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree; +fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize { + if (src.len < compressed_size) return error.MalformedHuffmanTree; + var stream = std.io.fixedBufferStream(src[0..compressed_size]); + var counting_reader = std.io.countingReader(stream.reader()); + var bit_reader = bitReader(counting_reader.reader()); + + var entries: [1 << 6]Table.Fse = undefined; + const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { + error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, + error.EndOfStream => return error.MalformedFseTable, + }; + const accuracy_log = std.math.log2_int_ceil(usize, table_size); + + const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree; + var huff_data = src[start_index..compressed_size]; + var huff_bits: ReverseBitReader = undefined; + huff_bits.init(huff_data) catch return error.MalformedHuffmanTree; + + return assignWeights(&huff_bits, accuracy_log, &entries, weights); +} + +fn assignWeights(huff_bits: *ReverseBitReader, 
accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize { + var i: usize = 0; + var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; + var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; + + while (i < 255) { + const even_data = entries[even_state]; + var read_bits: usize = 0; + const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable; + weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree; + i += 1; + if (read_bits < even_data.bits) { + weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree; i += 1; - if (read_bits < even_data.bits) { - weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree; - i += 1; - break; - } - even_state = even_data.baseline + even_bits; + break; + } + even_state = even_data.baseline + even_bits; - read_bits = 0; - const odd_data = entries[odd_state]; - const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable; - weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree; + read_bits = 0; + const odd_data = entries[odd_state]; + const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable; + weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree; + i += 1; + if (read_bits < odd_data.bits) { + if (i == 256) return error.MalformedHuffmanTree; + weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree; i += 1; - if (read_bits < odd_data.bits) { - if (i == 256) return error.MalformedHuffmanTree; - weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree; - i += 1; - break; - } - odd_state = odd_data.baseline + odd_bits; - } else return error.MalformedHuffmanTree; - - 
symbol_count = i + 1; // stream contains all but the last symbol - bytes_read += compressed_size; - } else { - const encoded_symbol_count = header - 127; - symbol_count = encoded_symbol_count + 1; - const weights_byte_count = (encoded_symbol_count + 1) / 2; - if (src.len < weights_byte_count) return error.MalformedHuffmanTree; - var i: usize = 0; - while (i < weights_byte_count) : (i += 1) { - weights[2 * i] = @intCast(u4, src[i + 1] >> 4); - weights[2 * i + 1] = @intCast(u4, src[i + 1] & 0xF); + break; } - bytes_read += weights_byte_count; - } - var weight_power_sum: u16 = 0; - for (weights[0 .. symbol_count - 1]) |value| { - if (value > 0) { - weight_power_sum += @as(u16, 1) << (value - 1); - } - } + odd_state = odd_data.baseline + odd_bits; + } else return error.MalformedHuffmanTree; - // advance to next power of two (even if weight_power_sum is a power of 2) - max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1; - const next_power_of_two = @as(u16, 1) << max_number_of_bits; - weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1; + return i + 1; // stream contains all but the last symbol +} - var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined; - for (weight_sorted_prefixed_symbols[0..symbol_count]) |_, i| { +fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize { + const weights_byte_count = (encoded_symbol_count + 1) / 2; + var i: usize = 0; + while (i < weights_byte_count) : (i += 1) { + const byte = try source.readByte(); + weights[2 * i] = @intCast(u4, byte >> 4); + weights[2 * i + 1] = @intCast(u4, byte & 0xF); + } + return encoded_symbol_count + 1; +} + +fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize { + for (weight_sorted_prefixed_symbols) |_, i| { weight_sorted_prefixed_symbols[i] = .{ .symbol = @intCast(u8, i), .weight = undefined, @@ 
-1259,7 +1338,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError std.sort.sort( LiteralsSection.HuffmanTree.PrefixedSymbol, - weight_sorted_prefixed_symbols[0..symbol_count], + weight_sorted_prefixed_symbols, weights, lessThanByWeight, ); @@ -1267,6 +1346,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError var prefix: u16 = 0; var prefixed_symbol_count: usize = 0; var sorted_index: usize = 0; + const symbol_count = weight_sorted_prefixed_symbols.len; while (sorted_index < symbol_count) { var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol; const weight = weights[symbol]; @@ -1290,7 +1370,24 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight; } } - consumed_count.* += bytes_read; + return prefixed_symbol_count; +} + +fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree { + var weight_power_sum: u16 = 0; + for (weights[0 .. 
symbol_count - 1]) |value| { + if (value > 0) { + weight_power_sum += @as(u16, 1) << (value - 1); + } + } + + // advance to next power of two (even if weight_power_sum is a power of 2) + const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1; + const next_power_of_two = @as(u16, 1) << max_number_of_bits; + weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1; + + var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined; + const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*); const tree = LiteralsSection.HuffmanTree{ .max_bit_count = max_number_of_bits, .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1), @@ -1299,6 +1396,37 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError return tree; } +fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree { + const header = try source.readByte(); + var weights: [256]u4 = undefined; + const symbol_count = if (header < 128) + // FSE compressed weights + try decodeFseHuffmanTree(source, header, buffer, &weights) + else + try decodeDirectHuffmanTree(source, header - 127, &weights); + + return buildHuffmanTree(&weights, symbol_count); +} + +fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) (error{EndOfStream} || DecodeHuffmanError)!LiteralsSection.HuffmanTree { + if (src.len == 0) return error.MalformedHuffmanTree; + const header = src[0]; + var bytes_read: usize = 1; + var weights: [256]u4 = undefined; + const symbol_count = if (header < 128) count: { + // FSE compressed weights + bytes_read += header; + break :count try decodeFseHuffmanTreeSlice(src[1..], header, &weights); + } else count: { + var fbs = std.io.fixedBufferStream(src[1..]); + defer bytes_read += fbs.pos; + break :count try decodeDirectHuffmanTree(fbs.reader(), header - 127, &weights); + }; + + consumed_count.* += bytes_read; + 
return buildHuffmanTree(&weights, symbol_count); +} + fn lessThanByWeight( weights: [256]u4, lhs: LiteralsSection.HuffmanTree.PrefixedSymbol, @@ -1311,9 +1439,8 @@ fn lessThanByWeight( } /// Decode a literals section header. -pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{MalformedLiteralsHeader}!LiteralsSection.Header { - if (src.len == 0) return error.MalformedLiteralsHeader; - const byte0 = src[0]; +pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header { + const byte0 = try source.readByte(); const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11); const size_format = @intCast(u2, (byte0 & 0b1100) >> 2); var regenerated_size: u20 = undefined; @@ -1323,47 +1450,31 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{Malfo switch (size_format) { 0, 2 => { regenerated_size = byte0 >> 3; - consumed_count.* += 1; - }, - 1 => { - if (src.len < 2) return error.MalformedLiteralsHeader; - regenerated_size = (byte0 >> 4) + - (@as(u20, src[1]) << 4); - consumed_count.* += 2; - }, - 3 => { - if (src.len < 3) return error.MalformedLiteralsHeader; - regenerated_size = (byte0 >> 4) + - (@as(u20, src[1]) << 4) + - (@as(u20, src[2]) << 12); - consumed_count.* += 3; }, + 1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4), + 3 => regenerated_size = (byte0 >> 4) + + (@as(u20, try source.readByte()) << 4) + + (@as(u20, try source.readByte()) << 12), } }, .compressed, .treeless => { - const byte1 = src[1]; - const byte2 = src[2]; + const byte1 = try source.readByte(); + const byte2 = try source.readByte(); switch (size_format) { 0, 1 => { - if (src.len < 3) return error.MalformedLiteralsHeader; regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4); compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2); - consumed_count.* += 3; }, 2 => { - if (src.len < 4) return error.MalformedLiteralsHeader; - const byte3 = src[3]; + const byte3 = try 
source.readByte(); regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12); compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6); - consumed_count.* += 4; }, 3 => { - if (src.len < 5) return error.MalformedLiteralsHeader; - const byte3 = src[3]; - const byte4 = src[4]; + const byte3 = try source.readByte(); + const byte4 = try source.readByte(); regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12); compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10); - consumed_count.* += 5; }, } }, @@ -1377,18 +1488,17 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{Malfo } /// Decode a sequences section header. +/// +/// Errors: +/// - returns `error.ReservedBitSet` if the reserved bit is set +/// - returns `error.MalformedSequencesHeader` if the header is invalid pub fn decodeSequencesHeader( - src: []const u8, - consumed_count: *usize, -) error{ MalformedSequencesHeader, ReservedBitSet }!SequencesSection.Header { - if (src.len == 0) return error.MalformedSequencesHeader; + source: anytype, +) !SequencesSection.Header { var sequence_count: u24 = undefined; - var bytes_read: usize = 0; - const byte0 = src[0]; + const byte0 = try source.readByte(); if (byte0 == 0) { - bytes_read += 1; - consumed_count.* += bytes_read; return SequencesSection.Header{ .sequence_count = 0, .offsets = undefined, @@ -1397,22 +1507,14 @@ pub fn decodeSequencesHeader( }; } else if (byte0 < 128) { sequence_count = byte0; - bytes_read += 1; } else if (byte0 < 255) { - if (src.len < 2) return error.MalformedSequencesHeader; - sequence_count = (@as(u24, (byte0 - 128)) << 8) + src[1]; - bytes_read += 2; + sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte(); } else { - if (src.len < 3) return error.MalformedSequencesHeader; - sequence_count = src[1] + (@as(u24, src[2]) << 8) + 0x7F00; - bytes_read += 3; + 
sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00; } - if (src.len < bytes_read + 1) return error.MalformedSequencesHeader; - const compression_modes = src[bytes_read]; - bytes_read += 1; + const compression_modes = try source.readByte(); - consumed_count.* += bytes_read; const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2); const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4); const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6); @@ -1615,7 +1717,7 @@ fn BitReader(comptime Reader: type) type { }; } -fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { +pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { return .{ .underlying = std.io.bitReader(.Little, reader) }; }