diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index ea72e33730..090110f9d0 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -573,6 +573,11 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match; const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match; +pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 { + const hash = hasher.final(); + return @intCast(u32, hash & 0xFFFFFFFF); +} + const FrameError = error{ DictionaryIdFlagUnsupported, ChecksumFailure, @@ -601,24 +606,20 @@ pub fn decodeZStandardFrame( if (dest.len < content_size) return error.ContentTooLarge; const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum; - var hash_state = if (should_compute_checksum) std.hash.XxHash64.init(0) else undefined; + var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null; const written_count = try decodeFrameBlocks( dest, src[consumed_count..], &consumed_count, - if (should_compute_checksum) &hash_state else null, + if (hasher_opt) |*hasher| hasher else null, ); if (frame_header.descriptor.content_checksum_flag) { const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]); consumed_count += 4; - if (verify_checksum) { - const hash = hash_state.final(); - const hash_low_bytes = hash & 0xFFFFFFFF; - if (checksum != hash_low_bytes) { - return error.ChecksumFailure; - } + if (hasher_opt) |*hasher| { + if (checksum != computeChecksum(hasher)) return error.ChecksumFailure; } } return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count }; @@ -649,7 +650,7 @@ pub fn decodeZStandardFrameAlloc( @intCast(usize, window_size_raw); const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum; - var hash = if (should_compute_checksum) std.hash.XxHash64.init(0) else null; + var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null; const block_size_maximum = @min(1 << 17, window_size); @@ -707,12 +708,20 @@ pub fn decodeZStandardFrameAlloc( const written_slice = ring_buffer.sliceLast(written_size); try result.appendSlice(written_slice.first); try result.appendSlice(written_slice.second); - if (hash) |*hash_state| { - hash_state.update(written_slice.first); - hash_state.update(written_slice.second); + if (hasher_opt) |*hasher| { + hasher.update(written_slice.first); + hasher.update(written_slice.second); } if (block_header.last_block) break; } + + if (frame_header.descriptor.content_checksum_flag) { + const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]); + consumed_count += 4; + if (hasher_opt) |*hasher| { + if (checksum != computeChecksum(hasher)) return error.ChecksumFailure; + } + } return result.toOwnedSlice(); }