From 32cf1d7cbf7852b7c66d1c026b0003690e9f7337 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Tue, 21 Feb 2023 17:14:45 +1100 Subject: [PATCH] std.compress.zstandard: fix error sets for streaming API --- lib/std/compress/zstandard/decode/huffman.zig | 28 +++++++++++++++---- lib/std/compress/zstandard/decompress.zig | 6 ++-- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/lib/std/compress/zstandard/decode/huffman.zig b/lib/std/compress/zstandard/decode/huffman.zig index 68aac85320..2914198268 100644 --- a/lib/std/compress/zstandard/decode/huffman.zig +++ b/lib/std/compress/zstandard/decode/huffman.zig @@ -15,7 +15,12 @@ pub const Error = error{ EndOfStream, }; -fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize { +fn decodeFseHuffmanTree( + source: anytype, + compressed_size: usize, + buffer: []u8, + weights: *[256]u4, +) !usize { var stream = std.io.limitedReader(source, compressed_size); var bit_reader = readers.bitReader(stream.reader()); @@ -23,6 +28,7 @@ fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, w const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) { error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e, error.EndOfStream => return error.MalformedFseTable, + else => |e| return e, }; const accuracy_log = std.math.log2_int_ceil(usize, table_size); @@ -46,7 +52,8 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: * }; const accuracy_log = std.math.log2_int_ceil(usize, table_size); - const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree; + const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse + return error.MalformedHuffmanTree; var huff_data = src[start_index..compressed_size]; var huff_bits: readers.ReverseBitReader = undefined; huff_bits.init(huff_data) catch return error.MalformedHuffmanTree; @@ -54,7 +61,12 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: * return assignWeights(&huff_bits, accuracy_log, &entries, weights); } -fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize { +fn assignWeights( + huff_bits: *readers.ReverseBitReader, + accuracy_log: usize, + entries: *[1 << 6]Table.Fse, + weights: *[256]u4, +) !usize { var i: usize = 0; var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree; @@ -173,7 +185,10 @@ fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffm return tree; } -pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree { +pub fn decodeHuffmanTree( + source: anytype, + buffer: []u8, +) (@TypeOf(source).Error || Error)!LiteralsSection.HuffmanTree { const header = try source.readByte(); var weights: [256]u4 = undefined; const symbol_count = if (header < 128) @@ -185,7 +200,10 @@ pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.Huffman return buildHuffmanTree(&weights, symbol_count); } -pub fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) Error!LiteralsSection.HuffmanTree { +pub fn decodeHuffmanTreeSlice( + src: []const u8, + consumed_count: *usize, +) Error!LiteralsSection.HuffmanTree { if (src.len == 0) return error.MalformedHuffmanTree; const header = src[0]; var bytes_read: usize = 1; diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index ffa01b94f1..a2ba59e688 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -64,7 +64,7 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet }; /// - `error.EndOfStream` if `source` contains fewer than 4 bytes /// - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the /// reserved bits are set -pub fn decodeFrameHeader(source: anytype) HeaderError!FrameHeader { +pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader { const magic = try source.readIntLittle(u32); const frame_type = try frameType(magic); switch (frame_type) { @@ -596,7 +596,9 @@ pub fn frameWindowSize(header: ZstandardHeader) ?u64 { /// Errors returned: /// - `error.ReservedBitSet` if any of the reserved bits of the header are set /// - `error.EndOfStream` if `source` does not contain a complete header -pub fn decodeZstandardHeader(source: anytype) error{ EndOfStream, ReservedBitSet }!ZstandardHeader { +pub fn decodeZstandardHeader( + source: anytype, +) (@TypeOf(source).Error || error{ EndOfStream, ReservedBitSet })!ZstandardHeader { const descriptor = @bitCast(ZstandardHeader.Descriptor, try source.readByte()); if (descriptor.reserved) return error.ReservedBitSet;