diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index a0f968736d..ab1d476dd6 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -148,8 +148,8 @@ pub const DecodeState = struct { } fn updateRepeatOffset(self: *DecodeState, offset: u32) void { - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]); - std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]); + self.repeat_offsets[2] = self.repeat_offsets[1]; + self.repeat_offsets[1] = self.repeat_offsets[0]; self.repeat_offsets[0] = offset; } @@ -238,18 +238,22 @@ pub const DecodeState = struct { fn nextSequence( self: *DecodeState, bit_reader: *readers.ReverseBitReader, - ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence { + ) error{ InvalidBitStream, EndOfStream }!Sequence { const raw_code = self.getCode(.offset); const offset_code = std.math.cast(u5, raw_code) orelse { - return error.OffsetCodeTooLarge; + return error.InvalidBitStream; }; const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code); const match_code = self.getCode(.match); + if (match_code >= types.compressed_block.match_length_code_table.len) + return error.InvalidBitStream; const match = types.compressed_block.match_length_code_table[match_code]; const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]); const literal_code = self.getCode(.literal); + if (literal_code >= types.compressed_block.literals_length_code_table.len) + return error.InvalidBitStream; const literal = types.compressed_block.literals_length_code_table[literal_code]; const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]); @@ -269,6 +273,8 @@ pub const DecodeState = struct { break :offset self.useRepeatOffset(offset_value - 1); }; + if (offset == 0) return error.InvalidBitStream; + return .{ .literal_length = literal_length, .match_length = match_length, @@ -308,7 +314,7 @@ pub const DecodeState = struct { } const DecodeSequenceError = error{ - OffsetCodeTooLarge, + InvalidBitStream, EndOfStream, MalformedSequence, MalformedFseBits, @@ -326,7 +332,7 @@ pub const DecodeState = struct { /// - `error.UnexpectedEndOfLiteralStream` if the decoder state's literal /// streams do not contain enough literals for the sequence (this may /// mean the literal stream or the sequence is malformed). - /// - `error.OffsetCodeTooLarge` if an invalid offset code is found + /// - `error.InvalidBitStream` if the FSE sequence bitstream is malformed /// - `error.EndOfStream` if `bit_reader` does not contain enough bits pub fn decodeSequenceSlice( self: *DecodeState, @@ -608,9 +614,9 @@ pub fn decodeBlock( .compressed => { if (src.len < block_size) return error.MalformedBlockSize; var bytes_read: usize = 0; - const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch + const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch return error.MalformedCompressedBlock; - var fbs = std.io.fixedBufferStream(src[bytes_read..]); + var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]); const fbs_reader = fbs.reader(); const sequences_header = decodeSequencesHeader(fbs_reader) catch return error.MalformedCompressedBlock; @@ -695,9 +701,9 @@ pub fn decodeBlockRingBuffer( .compressed => { if (src.len < block_size) return error.MalformedBlockSize; var bytes_read: usize = 0; - const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch + const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch return error.MalformedCompressedBlock; - var fbs = std.io.fixedBufferStream(src[bytes_read..]); + var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]); const fbs_reader = fbs.reader(); const sequences_header = decodeSequencesHeader(fbs_reader) catch return error.MalformedCompressedBlock; @@ -894,7 +900,8 @@ pub fn decodeLiteralsSectionSlice( else null; const huffman_tree_size = bytes_read - huffman_tree_start; - const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size; + const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch + return error.MalformedLiteralsSection; if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection; const stream_data = src[bytes_read .. bytes_read + total_streams_size]; @@ -940,8 +947,9 @@ pub fn decodeLiteralsSection( try huffman.decodeHuffmanTree(counting_reader.reader(), buffer) else null; - const huffman_tree_size = counting_reader.bytes_read; - const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size); + const huffman_tree_size = @intCast(usize, counting_reader.bytes_read); + const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch + return error.MalformedLiteralsSection; if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall; try source.readNoEof(buffer[0..total_streams_size]); diff --git a/lib/std/compress/zstandard/decode/huffman.zig b/lib/std/compress/zstandard/decode/huffman.zig index c759bfd6ab..01913c7044 100644 --- a/lib/std/compress/zstandard/decode/huffman.zig +++ b/lib/std/compress/zstandard/decode/huffman.zig @@ -146,13 +146,14 @@ fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.P return prefixed_symbol_count; } -fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree { +fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffmanTree}!LiteralsSection.HuffmanTree { var weight_power_sum: u16 = 0; for (weights[0 .. symbol_count - 1]) |value| { if (value > 0) { weight_power_sum += @as(u16, 1) << (value - 1); } } + if (weight_power_sum >= 1 << 11) return error.MalformedHuffmanTree; // advance to next power of two (even if weight_power_sum is a power of 2) const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1; diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index da6e161043..cf172a16ca 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -195,6 +195,7 @@ pub fn decodeZstandardFrame( ); if (frame_header.descriptor.content_checksum_flag) { + if (src.len < consumed_count + 4) return error.EndOfStream; const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]); consumed_count += 4; if (hasher_opt) |*hasher| { @@ -302,17 +303,20 @@ pub fn decodeZstandardFrameAlloc( &consumed_count, frame_context.block_size_max, ); - const written_slice = ring_buffer.sliceLast(written_size); - try result.appendSlice(written_slice.first); - try result.appendSlice(written_slice.second); - if (frame_context.hasher_opt) |*hasher| { - hasher.update(written_slice.first); - hasher.update(written_slice.second); + if (written_size > 0) { + const written_slice = ring_buffer.sliceLast(written_size); + try result.appendSlice(written_slice.first); + try result.appendSlice(written_slice.second); + if (frame_context.hasher_opt) |*hasher| { + hasher.update(written_slice.first); + hasher.update(written_slice.second); + } } if (block_header.last_block) break; } if (frame_context.has_checksum) { + if (src.len < consumed_count + 4) return error.EndOfStream; const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]); consumed_count += 4; if (frame_context.hasher_opt) |*hasher| {