From 18091723d5afa8001e0fd71274dc4b74d601d0e1 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Sun, 22 Jan 2023 13:32:16 +1100 Subject: [PATCH] std.compress.zstandard: cleanup decodeBlock --- lib/std/compress/zstandard/decompress.zig | 148 ++++++++++++---------- 1 file changed, 78 insertions(+), 70 deletions(-) diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 37fe722a84..9483b4d9d7 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -65,12 +65,14 @@ const DecodeState = struct { match_fse_buffer: []Table.Fse, literal_fse_buffer: []Table.Fse, - literal_written_count: usize, + fse_tables_undefined: bool, literal_stream_reader: ReverseBitReader(ReversedByteReader.Reader), literal_stream_bytes: ReversedByteReader, literal_stream_index: usize, - huffman_tree: Literals.HuffmanTree, + huffman_tree: ?Literals.HuffmanTree, + + literal_written_count: usize, fn StateData(comptime max_accuracy_log: comptime_int) type { return struct { @@ -129,7 +131,6 @@ const DecodeState = struct { src: []const u8, comptime choice: DataType, mode: Sequences.Header.Mode, - first_compressed_block: bool, ) !usize { const field_name = @tagName(choice); switch (mode) { @@ -162,7 +163,7 @@ const DecodeState = struct { dumpFseTable(field_name, @field(self, field_name).table.fse); return counting_reader.bytes_read; }, - .repeat => return if (first_compressed_block) error.RepeatModeFirst else 0, + .repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0, } } @@ -275,7 +276,7 @@ const DecodeState = struct { }, .compressed, .treeless => { // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; - const huffman_tree = self.huffman_tree; + const huffman_tree = self.huffman_tree orelse unreachable; const max_bit_count = huffman_tree.max_bit_count; const starting_bit_count = Literals.HuffmanTree.weightToBitCount( huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, @@ -399,14 +400,14 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha .match_fse_buffer = &match_fse_data, .offset_fse_buffer = &offset_fse_data, + .fse_tables_undefined = true, + .literal_written_count = 0, .literal_stream_reader = undefined, .literal_stream_bytes = undefined, .literal_stream_index = undefined, - .huffman_tree = undefined, + .huffman_tree = null, }; - var first_compressed_block = true; - var first_compressed_literals = true; var written_count: usize = 0; while (true) : ({ block_header = decodeBlockHeader(src[bytes_read..][0..3]); @@ -417,8 +418,6 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha src[bytes_read..], block_header, &decode_state, - &first_compressed_block, - &first_compressed_literals, &bytes_read, written_count, ); @@ -430,13 +429,77 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha return written_count; } +fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize { + log.debug("writing raw block - size {d}", .{block_size}); + const data = src[0..block_size]; + std.mem.copy(u8, dest, data); + consumed_count.* += block_size; + return block_size; +} + +fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize { + log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size }); + var write_pos: usize = 0; + while (write_pos < block_size) : (write_pos += 1) { + dest[write_pos] = src[0]; + } + consumed_count.* += 1; + return block_size; +} + +fn prepareDecodeState( + decode_state: *DecodeState, + src: []const u8, + literals: Literals, + sequences_header: Sequences.Header, +) !usize { + if (literals.huffman_tree) |tree| { + decode_state.huffman_tree = tree; + } else if (literals.header.block_type == .treeless and decode_state.huffman_tree == null) { + return error.TreelessLiteralsFirst; + } + + switch (literals.header.block_type) { + .raw, .rle => {}, + .compressed, .treeless => { + decode_state.literal_stream_index = 0; + switch (literals.streams) { + .one => |slice| try decode_state.initLiteralStream(slice), + .four => |streams| try decode_state.initLiteralStream(streams[0]), + } + }, + } + + if (sequences_header.sequence_count > 0) { + var bytes_read = try decode_state.updateFseTable( + src, + .literal, + sequences_header.literal_lengths, + ); + + bytes_read += try decode_state.updateFseTable( + src[bytes_read..], + .offset, + sequences_header.offsets, + ); + + bytes_read += try decode_state.updateFseTable( + src[bytes_read..], + .match, + sequences_header.match_lengths, + ); + decode_state.fse_tables_undefined = false; + + return bytes_read; + } + return 0; +} + pub fn decodeBlock( dest: []u8, src: []const u8, block_header: frame.ZStandard.Block.Header, decode_state: *DecodeState, - first_compressed_block: *bool, - first_compressed_literals: *bool, consumed_count: *usize, written_count: usize, ) !usize { @@ -445,69 +508,14 @@ pub fn decodeBlock( if (block_maximum_size < block_size) return error.BlockSizeOverMaximum; // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks) switch (block_header.block_type) { - .raw => { - log.debug("writing raw block - size {d}", .{block_size}); - const data = src[0..block_size]; - std.mem.copy(u8, dest[written_count..], data); - consumed_count.* += block_size; - return block_size; - }, - .rle => { - log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size }); - var write_pos: usize = written_count; - while (write_pos < block_size + written_count) : (write_pos += 1) { - dest[write_pos] = src[0]; - } - consumed_count.* += 1; - return block_size; - }, + .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count), + .rle => return decodeRleBlock(dest[written_count..], src, block_size, consumed_count), .compressed => { var bytes_read: usize = 0; const literals = try decodeLiteralsSection(src, &bytes_read); const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read); - if (first_compressed_literals.* and literals.header.block_type == .treeless) - return error.TreelessLiteralsFirst; - - if (literals.huffman_tree) |tree| { - decode_state.huffman_tree = tree; - first_compressed_literals.* = false; - } - - switch (literals.header.block_type) { - .raw, .rle => {}, - .compressed, .treeless => { - decode_state.literal_stream_index = 0; - switch (literals.streams) { - .one => |slice| try decode_state.initLiteralStream(slice), - .four => |streams| try decode_state.initLiteralStream(streams[0]), - } - }, - } - - if (sequences_header.sequence_count > 0) { - bytes_read += try decode_state.updateFseTable( - src[bytes_read..], - .literal, - sequences_header.literal_lengths, - first_compressed_block.*, - ); - - bytes_read += try decode_state.updateFseTable( - src[bytes_read..], - .offset, - sequences_header.offsets, - first_compressed_block.*, - ); - - bytes_read += try decode_state.updateFseTable( - src[bytes_read..], - .match, - sequences_header.match_lengths, - first_compressed_block.*, - ); - first_compressed_block.* = false; - } + bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header); var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) {