diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig
index 090110f9d0..0caf31fa33 100644
--- a/lib/std/compress/zstandard/decompress.zig
+++ b/lib/std/compress/zstandard/decompress.zig
@@ -81,6 +81,8 @@ pub const DecodeState = struct {
 
     literal_stream_reader: ReverseBitReader,
     literal_stream_index: usize,
+    literal_streams: LiteralsSection.Streams,
+    literal_header: LiteralsSection.Header,
     huffman_tree: ?LiteralsSection.HuffmanTree,
 
     literal_written_count: usize,
@@ -105,6 +107,10 @@ pub const DecodeState = struct {
         literals: LiteralsSection,
         sequences_header: SequencesSection.Header,
     ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize {
+        self.literal_written_count = 0;
+        self.literal_header = literals.header;
+        self.literal_streams = literals.streams;
+
         if (literals.huffman_tree) |tree| {
             self.huffman_tree = tree;
         } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
@@ -293,12 +299,11 @@ pub const DecodeState = struct {
         self: *DecodeState,
         dest: []u8,
         write_pos: usize,
-        literals: LiteralsSection,
         sequence: Sequence,
     ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;
 
-        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
+        try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length);
         const copy_start = write_pos + sequence.literal_length - sequence.offset;
         const copy_end = copy_start + sequence.match_length;
         // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
@@ -309,12 +314,11 @@ pub const DecodeState = struct {
     fn executeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         sequence: Sequence,
     ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > dest.data.len) return error.MalformedSequence;
 
-        try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
+        try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
         const copy_start = dest.write_index + dest.data.len - sequence.offset;
         const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
         // TODO: would std.mem.copy and figuring out dest slice be better/faster?
@@ -328,6 +332,7 @@ pub const DecodeState = struct {
         MalformedSequence,
         MalformedFseBits,
     } || DecodeLiteralsError;
+
     /// Decode one sequence from `bit_reader` into `dest`, written starting at
     /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
     /// `error.MalformedSequence` error if the decompressed sequence would be longer
@@ -340,7 +345,6 @@ pub const DecodeState = struct {
         self: *DecodeState,
         dest: []u8,
         write_pos: usize,
-        literals: LiteralsSection,
         bit_reader: *ReverseBitReader,
         sequence_size_limit: usize,
         last_sequence: bool,
@@ -349,7 +353,7 @@ pub const DecodeState = struct {
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
 
-        try self.executeSequenceSlice(dest, write_pos, literals, sequence);
+        try self.executeSequenceSlice(dest, write_pos, sequence);
         if (!last_sequence) {
             try self.updateState(.literal, bit_reader);
             try self.updateState(.match, bit_reader);
@@ -362,7 +366,6 @@ pub const DecodeState = struct {
     pub fn decodeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         bit_reader: anytype,
         sequence_size_limit: usize,
         last_sequence: bool,
@@ -371,7 +374,7 @@ pub const DecodeState = struct {
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
 
-        try self.executeSequenceRingBuffer(dest, literals, sequence);
+        try self.executeSequenceRingBuffer(dest, sequence);
         if (!last_sequence) {
             try self.updateState(.literal, bit_reader);
             try self.updateState(.match, bit_reader);
@@ -382,13 +385,12 @@ pub const DecodeState = struct {
 
     fn nextLiteralMultiStream(
         self: *DecodeState,
-        literals: LiteralsSection,
     ) error{BitStreamHasNoStartBit}!void {
         self.literal_stream_index += 1;
-        try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
+        try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
     }
 
-    fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
+    pub fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         try self.literal_stream_reader.init(bytes);
     }
 
@@ -400,11 +402,10 @@ pub const DecodeState = struct {
         self: *DecodeState,
         comptime T: type,
         bit_count_to_read: usize,
-        literals: LiteralsSection,
     ) LiteralBitsError!T {
         return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
-            if (literals.streams == .four and self.literal_stream_index < 3) {
-                try self.nextLiteralMultiStream(literals);
+            if (self.literal_streams == .four and self.literal_stream_index < 3) {
+                try self.nextLiteralMultiStream();
                 break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
                     return error.UnexpectedEndOfLiteralStream;
             } else {
@@ -427,23 +428,22 @@ pub const DecodeState = struct {
     pub fn decodeLiteralsSlice(
         self: *DecodeState,
         dest: []u8,
-        literals: LiteralsSection,
         len: usize,
     ) DecodeLiteralsError!void {
-        if (self.literal_written_count + len > literals.header.regenerated_size)
+        if (self.literal_written_count + len > self.literal_header.regenerated_size)
             return error.MalformedLiteralsLength;
 
-        switch (literals.header.block_type) {
+        switch (self.literal_header.block_type) {
             .raw => {
                 const literals_end = self.literal_written_count + len;
-                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
+                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 std.mem.copy(u8, dest, literal_data);
                 self.literal_written_count += len;
             },
             .rle => {
                 var i: usize = 0;
                 while (i < len) : (i += 1) {
-                    dest[i] = literals.streams.one[0];
+                    dest[i] = self.literal_streams.one[0];
                 }
                 self.literal_written_count += len;
             },
@@ -462,7 +462,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
+                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
@@ -496,23 +496,22 @@ pub const DecodeState = struct {
     pub fn decodeLiteralsRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         len: usize,
     ) DecodeLiteralsError!void {
-        if (self.literal_written_count + len > literals.header.regenerated_size)
+        if (self.literal_written_count + len > self.literal_header.regenerated_size)
            return error.MalformedLiteralsLength;
 
-        switch (literals.header.block_type) {
+        switch (self.literal_header.block_type) {
             .raw => {
                 const literals_end = self.literal_written_count + len;
-                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
+                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 dest.writeSliceAssumeCapacity(literal_data);
                 self.literal_written_count += len;
             },
             .rle => {
                 var i: usize = 0;
                 while (i < len) : (i += 1) {
-                    dest.writeAssumeCapacity(literals.streams.one[0]);
+                    dest.writeAssumeCapacity(self.literal_streams.one[0]);
                 }
                 self.literal_written_count += len;
             },
@@ -531,7 +530,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
+                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
@@ -569,10 +568,6 @@ pub const DecodeState = struct {
     }
 };
 
-const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.literal;
-const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
-const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
-
 pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
     const hash = hasher.final();
     return @intCast(u32, hash & 0xFFFFFFFF);
@@ -625,6 +620,31 @@ pub fn decodeZStandardFrame(
     return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count };
 }
 
+pub const FrameContext = struct {
+    hasher_opt: ?std.hash.XxHash64,
+    window_size: usize,
+    has_checksum: bool,
+    block_size_max: usize,
+
+    pub fn init(frame_header: frame.ZStandard.Header, window_size_max: usize, verify_checksum: bool) !FrameContext {
+        if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
+
+        const window_size_raw = frameWindowSize(frame_header) orelse return error.WindowSizeUnknown;
+        const window_size = if (window_size_raw > window_size_max)
+            return error.WindowTooLarge
+        else
+            @intCast(usize, window_size_raw);
+
+        const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
+        return .{
+            .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null,
+            .window_size = window_size,
+            .has_checksum = frame_header.descriptor.content_checksum_flag,
+            .block_size_max = @min(1 << 17, window_size),
+        };
+    }
+};
+
 /// Decode a Zstandard frame from `src` and return the decompressed bytes; see
 /// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame
 /// does not declare its content size or a window descriptor (this indicates a
@@ -639,33 +659,18 @@ pub fn decodeZStandardFrameAlloc(
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
 
-    const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
-
-    if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
-
-    const window_size_raw = frameWindowSize(frame_header) orelse return error.WindowSizeUnknown;
-    const window_size = if (window_size_raw > window_size_max)
-        return error.WindowTooLarge
-    else
-        @intCast(usize, window_size_raw);
-
-    const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
-    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
-
-    const block_size_maximum = @min(1 << 17, window_size);
-
-    var window_data = try allocator.alloc(u8, window_size);
-    defer allocator.free(window_data);
-    var ring_buffer = RingBuffer{
-        .data = window_data,
-        .write_index = 0,
-        .read_index = 0,
+    var frame_context = context: {
+        const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
+        break :context try FrameContext.init(frame_header, window_size_max, verify_checksum);
     };
 
+    var ring_buffer = try RingBuffer.init(allocator, frame_context.window_size);
+    defer ring_buffer.deinit(allocator);
+
     // These tables take 7680 bytes
-    var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
-    var match_fse_data: [match_table_size_max]Table.Fse = undefined;
-    var offset_fse_data: [offset_table_size_max]Table.Fse = undefined;
+    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
+    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
+    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;
 
     var block_header = decodeBlockHeader(src[consumed_count..][0..3]);
     consumed_count += 3;
@@ -687,6 +692,8 @@ pub fn decodeZStandardFrameAlloc(
 
         .fse_tables_undefined = true,
         .literal_written_count = 0,
+        .literal_header = undefined,
+        .literal_streams = undefined,
         .literal_stream_reader = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
@@ -695,30 +702,29 @@ pub fn decodeZStandardFrameAlloc(
         block_header = decodeBlockHeader(src[consumed_count..][0..3]);
         consumed_count += 3;
     }) {
-        if (block_header.block_size > block_size_maximum) return error.BlockSizeOverMaximum;
+        if (block_header.block_size > frame_context.block_size_max) return error.BlockSizeOverMaximum;
         const written_size = try decodeBlockRingBuffer(
             &ring_buffer,
             src[consumed_count..],
             block_header,
             &decode_state,
             &consumed_count,
-            block_size_maximum,
+            frame_context.block_size_max,
         );
-        if (written_size > block_size_maximum) return error.BlockSizeOverMaximum;
         const written_slice = ring_buffer.sliceLast(written_size);
         try result.appendSlice(written_slice.first);
         try result.appendSlice(written_slice.second);
-        if (hasher_opt) |*hasher| {
+        if (frame_context.hasher_opt) |*hasher| {
            hasher.update(written_slice.first);
            hasher.update(written_slice.second);
        }
         if (block_header.last_block) break;
     }
 
-    if (frame_header.descriptor.content_checksum_flag) {
+    if (frame_context.has_checksum) {
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
-        if (hasher_opt) |*hasher| {
+        if (frame_context.hasher_opt) |*hasher| {
            if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
         }
     }
@@ -741,9 +747,9 @@ pub fn decodeFrameBlocks(
     hash: ?*std.hash.XxHash64,
 ) DecodeBlockError!usize {
     // These tables take 7680 bytes
-    var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
-    var match_fse_data: [match_table_size_max]Table.Fse = undefined;
-    var offset_fse_data: [offset_table_size_max]Table.Fse = undefined;
+    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
+    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
+    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;
 
     var block_header = decodeBlockHeader(src[0..3]);
     var bytes_read: usize = 3;
@@ -766,6 +772,8 @@ pub fn decodeFrameBlocks(
 
         .fse_tables_undefined = true,
         .literal_written_count = 0,
+        .literal_header = undefined,
+        .literal_streams = undefined,
         .literal_stream_reader = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
@@ -867,7 +875,8 @@ pub fn decodeBlock(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const literals = decodeLiteralsSection(src, &bytes_read) catch
+                return error.MalformedCompressedBlock;
             const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
                 return error.MalformedCompressedBlock;
 
@@ -889,7 +898,6 @@ pub fn decodeBlock(
                 const decompressed_size = decode_state.decodeSequenceSlice(
                     dest,
                     write_pos,
-                    literals,
                     &bit_stream,
                     sequence_size_limit,
                     i == sequences_header.sequence_count - 1,
@@ -903,12 +911,11 @@ pub fn decodeBlock(
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len) catch
+                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
                     return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
-            decode_state.literal_written_count = 0;
             assert(bytes_read == block_header.block_size);
             consumed_count.* += bytes_read;
             return bytes_written;
@@ -936,7 +943,8 @@ pub fn decodeBlockRingBuffer(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const literals = decodeLiteralsSection(src, &bytes_read) catch
+                return error.MalformedCompressedBlock;
             const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
                 return error.MalformedCompressedBlock;
 
@@ -956,7 +964,6 @@ pub fn decodeBlockRingBuffer(
             while (i < sequences_header.sequence_count) : (i += 1) {
                 const decompressed_size = decode_state.decodeSequenceRingBuffer(
                     dest,
-                    literals,
                     &bit_stream,
                     sequence_size_limit,
                     i == sequences_header.sequence_count - 1,
@@ -970,14 +977,14 @@ pub fn decodeBlockRingBuffer(
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                decode_state.decodeLiteralsRingBuffer(dest, literals, len) catch
+                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                     return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
-            decode_state.literal_written_count = 0;
             assert(bytes_read == block_header.block_size);
             consumed_count.* += bytes_read;
+            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
             return bytes_written;
         },
         .reserved => return error.ReservedBlock,
diff --git a/lib/std/compress/zstandard/types.zig b/lib/std/compress/zstandard/types.zig
index 37a716a9d7..d94a55ebe5 100644
--- a/lib/std/compress/zstandard/types.zig
+++ b/lib/std/compress/zstandard/types.zig
@@ -386,6 +386,11 @@ pub const compressed_block = struct {
         pub const match = 6;
         pub const offset = 5;
     };
+    pub const table_size_max = struct {
+        pub const literal = 1 << table_accuracy_log_max.literal;
+        pub const match = 1 << table_accuracy_log_max.match;
+        pub const offset = 1 << table_accuracy_log_max.match;
+    };
 };
 
 test {