From 3bfba365483ccf30b197195cce8d5656f2c73736 Mon Sep 17 00:00:00 2001
From: dweiller <4678790+dweiller@users.noreply.github.com>
Date: Sat, 28 Jan 2023 21:03:55 +1100
Subject: [PATCH] std.compress.zstandard: clean up error sets and line lengths

---
 lib/std/compress/zstandard/decompress.zig | 307 +++++++++++++++-------
 lib/std/compress/zstandard/types.zig      |   2 +-
 2 files changed, 215 insertions(+), 94 deletions(-)

diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig
index 6a1ab364e7..ea72e33730 100644
--- a/lib/std/compress/zstandard/decompress.zig
+++ b/lib/std/compress/zstandard/decompress.zig
@@ -22,7 +22,7 @@ fn isSkippableMagic(magic: u32) bool {
 /// if the frame is skippable, `null` for Zstandard frames that do not
 /// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet`
 /// errors if the respective bits of the frame descriptor are set.
-pub fn getFrameDecompressedSize(src: []const u8) !?u64 {
+pub fn getFrameDecompressedSize(src: []const u8) (InvalidBit || error{BadMagic})!?u64 {
     switch (try frameType(src)) {
         .zstandard => {
             const header = try decodeZStandardHeader(src[4..], null);
@@ -52,7 +52,11 @@ const ReadWriteCount = struct {
 
 /// Decodes the frame at the start of `src` into `dest`. Returns the number of
 /// bytes read from `src` and written to `dest`.
-pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
+pub fn decodeFrame(
+    dest: []u8,
+    src: []const u8,
+    verify_checksum: bool,
+) (error{ UnknownContentSizeUnsupported, ContentTooLarge, BadMagic } || FrameError)!ReadWriteCount {
     return switch (try frameType(src)) {
         .zstandard => decodeZStandardFrame(dest, src, verify_checksum),
         .skippable => ReadWriteCount{
@@ -100,7 +104,7 @@ pub const DecodeState = struct {
         src: []const u8,
         literals: LiteralsSection,
         sequences_header: SequencesSection.Header,
-    ) !usize {
+    ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize {
         if (literals.huffman_tree) |tree| {
             self.huffman_tree = tree;
         } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
@@ -145,7 +149,7 @@ pub const DecodeState = struct {
 
     /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream`
    /// if `bit_reader` does not contain enough bits.
-    pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void {
+    pub fn readInitialFseState(self: *DecodeState, bit_reader: *ReverseBitReader) error{EndOfStream}!void {
         self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
         self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
         self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
@@ -169,7 +173,11 @@
 
     const DataType = enum { offset, match, literal };
 
-    fn updateState(self: *DecodeState, comptime choice: DataType, bit_reader: anytype) !void {
+    fn updateState(
+        self: *DecodeState,
+        comptime choice: DataType,
+        bit_reader: *ReverseBitReader,
+    ) error{ MalformedFseBits, EndOfStream }!void {
         switch (@field(self, @tagName(choice)).table) {
             .rle => {},
             .fse => |table| {
@@ -185,17 +193,27 @@
         }
     }
 
+    const FseTableError = error{
+        MalformedFseTable,
+        MalformedAccuracyLog,
+        RepeatModeFirst,
+        EndOfStream,
+    };
+
     fn updateFseTable(
         self: *DecodeState,
         src: []const u8,
         comptime choice: DataType,
         mode: SequencesSection.Header.Mode,
-    ) !usize {
+    ) FseTableError!usize {
         const field_name = @tagName(choice);
         switch (mode) {
             .predefined => {
-                @field(self, field_name).accuracy_log = @field(types.compressed_block.default_accuracy_log, field_name);
-                @field(self, field_name).table = @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
+                @field(self, field_name).accuracy_log =
+                    @field(types.compressed_block.default_accuracy_log, field_name);
+
+                @field(self, field_name).table =
+                    @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
                 return 0;
             },
             .rle => {
@@ -214,9 +232,11 @@
                     @field(types.compressed_block.table_accuracy_log_max, field_name),
                     @field(self, field_name ++ "_fse_buffer"),
                 );
-                @field(self, field_name).table = .{ .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size] };
+                @field(self, field_name).table = .{
+                    .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
+                };
                 @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
-                return std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedFseTable;
+                return std.math.cast(usize, counting_reader.bytes_read) orelse error.MalformedFseTable;
             },
             .repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0,
         }
@@ -228,7 +248,10 @@
         offset: u32,
     };
 
-    fn nextSequence(self: *DecodeState, bit_reader: anytype) !Sequence {
+    fn nextSequence(
+        self: *DecodeState,
+        bit_reader: *ReverseBitReader,
+    ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence {
         const raw_code = self.getCode(.offset);
         const offset_code = std.math.cast(u5, raw_code) orelse {
             return error.OffsetCodeTooLarge;
@@ -272,7 +295,7 @@
         write_pos: usize,
         literals: LiteralsSection,
         sequence: Sequence,
-    ) !void {
+    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;
 
         try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
@@ -288,16 +311,23 @@
         dest: *RingBuffer,
         literals: LiteralsSection,
         sequence: Sequence,
-    ) !void {
+    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > dest.data.len) return error.MalformedSequence;
 
         try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
-        const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
+        const copy_start = dest.write_index + dest.data.len - sequence.offset;
+        const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
         // TODO: would std.mem.copy and figuring out dest slice be better/faster?
         for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
         for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
     }
 
+    const DecodeSequenceError = error{
+        OffsetCodeTooLarge,
+        EndOfStream,
+        MalformedSequence,
+        MalformedFseBits,
+    } || DecodeLiteralsError;
+
     /// Decode one sequence from `bit_reader` into `dest`, written starting at
     /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
     /// `error.MalformedSequence` error if the decompressed sequence would be longer
@@ -311,10 +341,10 @@
         dest: []u8,
         write_pos: usize,
         literals: LiteralsSection,
-        bit_reader: anytype,
+        bit_reader: *ReverseBitReader,
         sequence_size_limit: usize,
         last_sequence: bool,
-    ) !usize {
+    ) DecodeSequenceError!usize {
         const sequence = try self.nextSequence(bit_reader);
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
@@ -336,7 +366,7 @@
         bit_reader: anytype,
         sequence_size_limit: usize,
         last_sequence: bool,
-    ) !usize {
+    ) DecodeSequenceError!usize {
         const sequence = try self.nextSequence(bit_reader);
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
@@ -350,26 +380,63 @@
         return sequence_length;
     }
 
-    fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void {
+    fn nextLiteralMultiStream(
+        self: *DecodeState,
+        literals: LiteralsSection,
+    ) error{BitStreamHasNoStartBit}!void {
         self.literal_stream_index += 1;
         try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
     }
 
-    fn initLiteralStream(self: *DecodeState, bytes: []const u8) !void {
+    fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         try self.literal_stream_reader.init(bytes);
     }
 
+    const LiteralBitsError = error{
+        BitStreamHasNoStartBit,
+        UnexpectedEndOfLiteralStream,
+    };
+    fn readLiteralsBits(
+        self: *DecodeState,
+        comptime T: type,
+        bit_count_to_read: usize,
+        literals: LiteralsSection,
+    ) LiteralBitsError!T {
+        return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
+            if (literals.streams == .four and self.literal_stream_index < 3) {
+                try self.nextLiteralMultiStream(literals);
+                break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
+                    return error.UnexpectedEndOfLiteralStream;
+            } else {
+                return error.UnexpectedEndOfLiteralStream;
+            }
+        };
+    }
+
+    const DecodeLiteralsError = error{
+        MalformedLiteralsLength,
+        PrefixNotFound,
+    } || LiteralBitsError;
+
     /// Decode `len` bytes of literals into `dest`. `literals` should be the
     /// `LiteralsSection` that was passed to `prepare()`. Returns
     /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by
     /// `self` plus `len` is greater than the regenerated size of `literals`.
     /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if
    /// there are problems decoding Huffman compressed literals.
-    pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void {
-        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+    pub fn decodeLiteralsSlice(
+        self: *DecodeState,
+        dest: []u8,
+        literals: LiteralsSection,
+        len: usize,
+    ) DecodeLiteralsError!void {
+        if (self.literal_written_count + len > literals.header.regenerated_size)
+            return error.MalformedLiteralsLength;
+
         switch (literals.header.block_type) {
             .raw => {
-                const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+                const literals_end = self.literal_written_count + len;
+                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
                 std.mem.copy(u8, dest, literal_data);
                 self.literal_written_count += len;
             },
@@ -395,15 +462,7 @@
             while (i < len) : (i += 1) {
                 var prefix: u16 = 0;
                 while (true) {
-                    const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
-                        switch (err) {
-                        error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
-                            try self.nextLiteralMultiStream(literals);
-                            break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
-                        } else {
-                            return error.UnexpectedEndOfLiteralStream;
-                        },
-                    };
+                    const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
                     prefix <<= bit_count_to_read;
                     prefix |= new_bits;
                     bits_read += bit_count_to_read;
@@ -434,11 +493,19 @@
     }
 
     /// Decode literals into `dest`; see `decodeLiteralsSlice()`.
-    pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void {
-        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+    pub fn decodeLiteralsRingBuffer(
+        self: *DecodeState,
+        dest: *RingBuffer,
+        literals: LiteralsSection,
+        len: usize,
+    ) DecodeLiteralsError!void {
+        if (self.literal_written_count + len > literals.header.regenerated_size)
+            return error.MalformedLiteralsLength;
+
         switch (literals.header.block_type) {
             .raw => {
-                const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+                const literals_end = self.literal_written_count + len;
+                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
                 dest.writeSliceAssumeCapacity(literal_data);
                 self.literal_written_count += len;
             },
@@ -464,15 +531,7 @@
             while (i < len) : (i += 1) {
                 var prefix: u16 = 0;
                 while (true) {
-                    const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
-                        switch (err) {
-                        error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
-                            try self.nextLiteralMultiStream(literals);
-                            break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
-                        } else {
-                            return error.UnexpectedEndOfLiteralStream;
-                        },
-                    };
+                    const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
                     prefix <<= bit_count_to_read;
                     prefix |= new_bits;
                     bits_read += bit_count_to_read;
@@ -514,6 +573,11 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma
 const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 
+const FrameError = error{
+    DictionaryIdFlagUnsupported,
+    ChecksumFailure,
+} || InvalidBit || DecodeBlockError;
+
 /// Decode a Zstandard frame from `src` into `dest`, returning the number of
 /// bytes read from `src` and written to `dest`; if the frame does not declare
 /// its decompressed content size `error.UnknownContentSizeUnsupported` is
 /// returned. Returns `error.DictionaryIdFlagUnsupported` if the frame uses a
 /// dictionary, and `error.ChecksumFailure` if `verify_checksum` is `true` and
 /// the frame contains a checksum that does not match the checksum computed from
 /// the decompressed frame.
-pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
+pub fn decodeZStandardFrame(
+    dest: []u8,
+    src: []const u8,
+    verify_checksum: bool,
+) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount {
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
 
@@ -530,13 +598,11 @@
     if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
 
     const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported;
-    // const window_size = frameWindowSize(header) orelse return error.WindowSizeUnknown;
     if (dest.len < content_size) return error.ContentTooLarge;
 
     const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
     var hash_state = if (should_compute_checksum) std.hash.XxHash64.init(0) else undefined;
 
-    // TODO: block_maximum_size should be @min(1 << 17, window_size);
     const written_count = try decodeFrameBlocks(
         dest,
         src[consumed_count..],
@@ -567,7 +633,7 @@
     src: []const u8,
     verify_checksum: bool,
     window_size_max: usize,
-) ![]u8 {
+) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory } || FrameError)![]u8 {
     var result = std.ArrayList(u8).init(allocator);
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
@@ -628,7 +694,7 @@
         block_header = decodeBlockHeader(src[consumed_count..][0..3]);
         consumed_count += 3;
     }) {
-        if (block_header.block_size > block_size_maximum) return error.CompressedBlockSizeOverMaximum;
+        if (block_header.block_size > block_size_maximum) return error.BlockSizeOverMaximum;
         const written_size = try decodeBlockRingBuffer(
             &ring_buffer,
             src[consumed_count..],
@@ -637,7 +703,7 @@
             &consumed_count,
             block_size_maximum,
         );
-        if (written_size > block_size_maximum) return error.DecompressedBlockSizeOverMaximum;
+        if (written_size > block_size_maximum) return error.BlockSizeOverMaximum;
         const written_slice = ring_buffer.sliceLast(written_size);
         try result.appendSlice(written_slice.first);
         try result.appendSlice(written_slice.second);
@@ -650,8 +716,21 @@
     return result.toOwnedSlice();
 }
 
+const DecodeBlockError = error{
+    BlockSizeOverMaximum,
+    MalformedBlockSize,
+    ReservedBlock,
+    MalformedRleBlock,
+    MalformedCompressedBlock,
+};
+
 /// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`.
-pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64) !usize {
+pub fn decodeFrameBlocks(
+    dest: []u8,
+    src: []const u8,
+    consumed_count: *usize,
+    hash: ?*std.hash.XxHash64,
+) DecodeBlockError!usize {
     // These tables take 7680 bytes
     var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
     var match_fse_data: [match_table_size_max]Table.Fse = undefined;
@@ -702,7 +781,12 @@
     return written_count;
 }
 
-fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRawBlock(
+    dest: []u8,
+    src: []const u8,
+    block_size: u21,
+    consumed_count: *usize,
+) error{MalformedBlockSize}!usize {
     if (src.len < block_size) return error.MalformedBlockSize;
     const data = src[0..block_size];
     std.mem.copy(u8, dest, data);
@@ -710,7 +794,12 @@
     return block_size;
 }
 
-fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRawBlockRingBuffer(
+    dest: *RingBuffer,
+    src: []const u8,
+    block_size: u21,
+    consumed_count: *usize,
+) error{MalformedBlockSize}!usize {
     if (src.len < block_size) return error.MalformedBlockSize;
     const data = src[0..block_size];
     dest.writeSliceAssumeCapacity(data);
@@ -718,7 +807,12 @@
     return block_size;
 }
 
-fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRleBlock(
+    dest: []u8,
+    src: []const u8,
+    block_size: u21,
+    consumed_count: *usize,
+) error{MalformedRleBlock}!usize {
     if (src.len < 1) return error.MalformedRleBlock;
     var write_pos: usize = 0;
     while (write_pos < block_size) : (write_pos += 1) {
@@ -728,7 +822,12 @@
     return block_size;
 }
 
-fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRleBlockRingBuffer(
+    dest: *RingBuffer,
+    src: []const u8,
+    block_size: u21,
+    consumed_count: *usize,
+) error{MalformedRleBlock}!usize {
     if (src.len < 1) return error.MalformedRleBlock;
     var write_pos: usize = 0;
     while (write_pos < block_size) : (write_pos += 1) {
@@ -749,7 +848,7 @@
     decode_state: *DecodeState,
     consumed_count: *usize,
     written_count: usize,
-) !usize {
+) DecodeBlockError!usize {
     const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
@@ -759,31 +858,33 @@
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = try decodeLiteralsSection(src, &bytes_read);
-            const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+                return error.MalformedCompressedBlock;
 
-            bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
+            bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+                return error.MalformedCompressedBlock;
 
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
                 var bit_stream: ReverseBitReader = undefined;
-                try bit_stream.init(bit_stream_bytes);
+                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
 
-                try decode_state.readInitialFseState(&bit_stream);
+                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
 
                 var sequence_size_limit = block_size_max;
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
                     const write_pos = written_count + bytes_written;
-                    const decompressed_size = try decode_state.decodeSequenceSlice(
+                    const decompressed_size = decode_state.decodeSequenceSlice(
                         dest,
                         write_pos,
                         literals,
                         &bit_stream,
                         sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
-                    );
+                    ) catch return error.MalformedCompressedBlock;
                     bytes_written += decompressed_size;
                     sequence_size_limit -= decompressed_size;
                 }
@@ -793,7 +894,8 @@
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
+                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len) catch
+                    return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
@@ -802,7 +904,7 @@
             consumed_count.* += bytes_read;
             return bytes_written;
         },
-        .reserved => return error.FrameContainsReservedBlock,
+        .reserved => return error.ReservedBlock,
     }
 }
 
@@ -816,7 +918,7 @@
     decode_state: *DecodeState,
    consumed_count: *usize,
    block_size_max: usize,
-) !usize {
+) DecodeBlockError!usize {
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
@@ -825,29 +927,31 @@
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = try decodeLiteralsSection(src, &bytes_read);
-            const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+                return error.MalformedCompressedBlock;
 
-            bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
+            bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+                return error.MalformedCompressedBlock;
 
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
                 var bit_stream: ReverseBitReader = undefined;
-                try bit_stream.init(bit_stream_bytes);
+                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
 
-                try decode_state.readInitialFseState(&bit_stream);
+                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
 
                 var sequence_size_limit = block_size_max;
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
-                    const decompressed_size = try decode_state.decodeSequenceRingBuffer(
+                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
                        dest,
                        literals,
                        &bit_stream,
                        sequence_size_limit,
                        i == sequences_header.sequence_count - 1,
-                    );
+                    ) catch return error.MalformedCompressedBlock;
                    bytes_written += decompressed_size;
                    sequence_size_limit -= decompressed_size;
                }
@@ -857,7 +961,8 @@
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
+                decode_state.decodeLiteralsRingBuffer(dest, literals, len) catch
+                    return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
@@ -866,7 +971,7 @@
             consumed_count.* += bytes_read;
             return bytes_written;
         },
-        .reserved => return error.FrameContainsReservedBlock,
+        .reserved => return error.ReservedBlock,
     }
 }
 
@@ -901,9 +1006,10 @@
     } else return header.content_size;
 }
 
+const InvalidBit = error{ UnusedBitSet, ReservedBitSet };
 /// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or
 /// `error.ReservedBitSet` if the corresponding bits are set.
-pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZStandard.Header {
+pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) InvalidBit!frame.ZStandard.Header {
     const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]);
 
     if (descriptor.unused) return error.UnusedBitSet;
@@ -958,7 +1064,10 @@
 
 /// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the
 /// number of bytes the section uses.
-pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
+pub fn decodeLiteralsSection(
+    src: []const u8,
+    consumed_count: *usize,
+) (error{ MalformedLiteralsHeader, MalformedLiteralsSection } || DecodeHuffmanError)!LiteralsSection {
     var bytes_read: usize = 0;
     const header = try decodeLiteralsHeader(src, &bytes_read);
     switch (header.block_type) {
@@ -1032,7 +1141,13 @@
     }
 }
 
-fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree {
+const DecodeHuffmanError = error{
+    MalformedHuffmanTree,
+    MalformedFseTable,
+    MalformedAccuracyLog,
+};
+
+fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError!LiteralsSection.HuffmanTree {
     var bytes_read: usize = 0;
     bytes_read += 1;
     if (src.len == 0) return error.MalformedHuffmanTree;
@@ -1049,22 +1164,25 @@
 
     var bit_reader = bitReader(counting_reader.reader());
     var entries: [1 << 6]Table.Fse = undefined;
-    const table_size = try decodeFseTable(&bit_reader, 256, 6, &entries);
+    const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
+        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
+        error.EndOfStream => return error.MalformedFseTable,
+    };
     const accuracy_log = std.math.log2_int_ceil(usize, table_size);
 
     const start_index = std.math.cast(usize, 1 + counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
     var huff_data = src[start_index .. compressed_size + 1];
     var huff_bits: ReverseBitReader = undefined;
-    try huff_bits.init(huff_data);
+    huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;
 
     var i: usize = 0;
-    var even_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
-    var odd_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
+    var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
+    var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
 
     while (i < 255) {
         const even_data = entries[even_state];
         var read_bits: usize = 0;
-        const even_bits = try huff_bits.readBits(u32, even_data.bits, &read_bits);
+        const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
         weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
         i += 1;
         if (read_bits < even_data.bits) {
@@ -1076,7 +1194,7 @@
 
         read_bits = 0;
         const odd_data = entries[odd_state];
-        const odd_bits = try huff_bits.readBits(u32, odd_data.bits, &read_bits);
+        const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
         weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
         i += 1;
         if (read_bits < odd_data.bits) {
@@ -1177,8 +1295,8 @@ fn lessThanByWeight(
 }
 
 /// Decode a literals section header.
-pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSection.Header {
-    if (src.len == 0) return error.MalformedLiteralsSection;
+pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{MalformedLiteralsHeader}!LiteralsSection.Header {
+    if (src.len == 0) return error.MalformedLiteralsHeader;
     const byte0 = src[0];
     const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
     const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
@@ -1243,8 +1361,11 @@
 }
 
 /// Decode a sequences section header.
-pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
-    if (src.len == 0) return error.MalformedSequencesSection;
+pub fn decodeSequencesHeader(
+    src: []const u8,
+    consumed_count: *usize,
+) error{ MalformedSequencesHeader, ReservedBitSet }!SequencesSection.Header {
+    if (src.len == 0) return error.MalformedSequencesHeader;
 
     var sequence_count: u24 = undefined;
     var bytes_read: usize = 0;
@@ -1262,16 +1383,16 @@
         sequence_count = byte0;
         bytes_read += 1;
     } else if (byte0 < 255) {
-        if (src.len < 2) return error.MalformedSequencesSection;
+        if (src.len < 2) return error.MalformedSequencesHeader;
         sequence_count = (@as(u24, (byte0 - 128)) << 8) + src[1];
         bytes_read += 2;
     } else {
-        if (src.len < 3) return error.MalformedSequencesSection;
+        if (src.len < 3) return error.MalformedSequencesHeader;
         sequence_count = src[1] + (@as(u24, src[2]) << 8) + 0x7F00;
         bytes_read += 3;
     }
 
-    if (src.len < bytes_read + 1) return error.MalformedSequencesSection;
+    if (src.len < bytes_read + 1) return error.MalformedSequencesHeader;
     const compression_modes = src[bytes_read];
     bytes_read += 1;
 
@@ -1441,17 +1562,17 @@
 pub const ReverseBitReader = struct {
     byte_reader: ReversedByteReader,
     bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),
 
-    pub fn init(self: *ReverseBitReader, bytes: []const u8) !void {
+    pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         self.byte_reader = ReversedByteReader.init(bytes);
         self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
         while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
     }
 
-    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
+    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
         return self.bit_reader.readBitsNoEof(U, num_bits);
     }
 
-    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
+    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
         return try self.bit_reader.readBits(U, num_bits, out_bits);
     }
 
diff --git a/lib/std/compress/zstandard/types.zig b/lib/std/compress/zstandard/types.zig
index f703dc29eb..37a716a9d7 100644
--- a/lib/std/compress/zstandard/types.zig
+++ b/lib/std/compress/zstandard/types.zig
@@ -92,7 +92,7 @@ pub const compressed_block = struct {
             index: usize,
         };
 
-        pub fn query(self: HuffmanTree, index: usize, prefix: u16) !Result {
+        pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{PrefixNotFound}!Result {
             var node = self.nodes[index];
             const weight = node.weight;
             var i: usize = index;