From cbfaa876d466a885c54ead16a5901399619ed0c8 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Mon, 23 Jan 2023 12:47:46 +1100 Subject: [PATCH] std.compress.zstandard: cleanup ReverseBitReader --- lib/std/compress/zstandard/decompress.zig | 91 +++++++++++------------ 1 file changed, 42 insertions(+), 49 deletions(-) diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index b5a37878d1..22ed22c0de 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -68,8 +68,7 @@ const DecodeState = struct { fse_tables_undefined: bool, - literal_stream_reader: ReverseBitReader(ReversedByteReader.Reader), - literal_stream_bytes: ReversedByteReader, + literal_stream_reader: ReverseBitReader, literal_stream_index: usize, huffman_tree: ?Literals.HuffmanTree, @@ -288,9 +287,7 @@ const DecodeState = struct { fn initLiteralStream(self: *DecodeState, bytes: []const u8) !void { log.debug("initing literal stream: {}", .{std.fmt.fmtSliceHexUpper(bytes)}); - self.literal_stream_bytes = reversedByteReader(bytes); - self.literal_stream_reader = reverseBitReader(self.literal_stream_bytes.reader()); - while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {} + try self.literal_stream_reader.init(bytes); } fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void { @@ -532,7 +529,6 @@ pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8, .literal_written_count = 0, .literal_stream_reader = undefined, - .literal_stream_bytes = undefined, .literal_stream_index = undefined, .huffman_tree = null, }; @@ -591,7 +587,6 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha .literal_written_count = 0, .literal_stream_reader = undefined, - .literal_stream_bytes = undefined, .literal_stream_index = undefined, .huffman_tree = null, }; @@ -725,10 +720,9 @@ pub fn decodeBlock( var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { const bit_stream_bytes = src[bytes_read..block_size]; - var reverse_byte_reader = reversedByteReader(bit_stream_bytes); - var bit_stream = reverseBitReader(reverse_byte_reader.reader()); + var bit_stream: ReverseBitReader = undefined; + try bit_stream.init(bit_stream_bytes); - while (0 == try bit_stream.readBitsNoEof(u1, 1)) {} try decode_state.readInitialState(&bit_stream); var i: usize = 0; @@ -791,10 +785,9 @@ pub fn decodeBlockRingBuffer( var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { const bit_stream_bytes = src[bytes_read..block_size]; - var reverse_byte_reader = reversedByteReader(bit_stream_bytes); - var bit_stream = reverseBitReader(reverse_byte_reader.reader()); + var bit_stream: ReverseBitReader = undefined; + try bit_stream.init(bit_stream_bytes); - while (0 == try bit_stream.readBitsNoEof(u1, 1)) {} try decode_state.readInitialState(&bit_stream); var i: usize = 0; @@ -1028,9 +1021,8 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT const accuracy_log = std.math.log2_int_ceil(usize, table_size); var huff_data = src[1 + counting_reader.bytes_read .. compressed_size + 1]; - var huff_data_bytes = reversedByteReader(huff_data); - var huff_bits = reverseBitReader(huff_data_bytes.reader()); - while (0 == try huff_bits.readBitsNoEof(u1, 1)) {} + var huff_bits: ReverseBitReader = undefined; + try huff_bits.init(huff_data); dumpFseTable("huffman", entries[0..table_size]); @@ -1415,48 +1407,49 @@ const ReversedByteReader = struct { const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn); + fn init(bytes: []const u8) ReversedByteReader { + return .{ + .bytes = bytes, + .remaining_bytes = bytes.len, + }; + } + fn reader(self: *ReversedByteReader) Reader { return .{ .context = self }; } + + fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { + if (ctx.remaining_bytes == 0) return 0; + const byte_index = ctx.remaining_bytes - 1; + buffer[0] = ctx.bytes[byte_index]; + // buffer[0] = @bitReverse(ctx.bytes[byte_index]); + ctx.remaining_bytes = byte_index; + return 1; + } }; -fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { - if (ctx.remaining_bytes == 0) return 0; - const byte_index = ctx.remaining_bytes - 1; - buffer[0] = ctx.bytes[byte_index]; - // buffer[0] = @bitReverse(ctx.bytes[byte_index]); - ctx.remaining_bytes = byte_index; - return 1; -} +const ReverseBitReader = struct { + byte_reader: ReversedByteReader, + bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader), -fn reversedByteReader(bytes: []const u8) ReversedByteReader { - return ReversedByteReader{ - .remaining_bytes = bytes.len, - .bytes = bytes, - }; -} + fn init(self: *ReverseBitReader, bytes: []const u8) !void { + self.byte_reader = ReversedByteReader.init(bytes); + self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); + while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} + } -fn ReverseBitReader(comptime Reader: type) type { - return struct { - underlying: std.io.BitReader(.Big, Reader), + fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { + return self.bit_reader.readBitsNoEof(U, num_bits); + } - fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { - return self.underlying.readBitsNoEof(U, num_bits); - } + fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { + return try self.bit_reader.readBits(U, num_bits, out_bits); + } - fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { - return try self.underlying.readBits(U, num_bits, out_bits); - } - - fn alignToByte(self: *@This()) void { - self.underlying.alignToByte(); - } - }; -} - -fn reverseBitReader(reader: anytype) ReverseBitReader(@TypeOf(reader)) { - return .{ .underlying = std.io.bitReader(.Big, reader) }; -} + fn alignToByte(self: *@This()) void { + self.bit_reader.alignToByte(); + } +}; fn BitReader(comptime Reader: type) type { return struct {