diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index d83b3a3336..6fb6a70027 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -1,8 +1,158 @@ const std = @import("std"); +const Allocator = std.mem.Allocator; +const types = @import("zstandard/types.zig"); + +const RingBuffer = @import("zstandard/RingBuffer.zig"); pub const decompress = @import("zstandard/decompress.zig"); pub usingnamespace @import("zstandard/types.zig"); +pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool, comptime window_size_max: usize) type { + return struct { + const Self = @This(); + + allocator: Allocator, + in_reader: ReaderType, + decode_state: decompress.DecodeState, + frame_context: decompress.FrameContext, + buffer: RingBuffer, + last_block: bool, + literal_fse_buffer: []types.compressed_block.Table.Fse, + match_fse_buffer: []types.compressed_block.Table.Fse, + offset_fse_buffer: []types.compressed_block.Table.Fse, + literals_buffer: []u8, + sequence_buffer: []u8, + checksum: if (verify_checksum) ?u32 else void, + + pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame, EndOfStream }; + + pub const Reader = std.io.Reader(*Self, Error, read); + + pub fn init(allocator: Allocator, source: ReaderType) !Self { + switch (try decompress.decodeFrameType(source)) { + .skippable => return error.SkippableFrame, + .zstandard => { + const frame_context = context: { + const frame_header = try decompress.decodeZStandardHeader(source); + break :context try decompress.FrameContext.init(frame_header, window_size_max, verify_checksum); + }; + + const literal_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.literal); + errdefer allocator.free(literal_fse_buffer); + const match_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.match); + errdefer allocator.free(match_fse_buffer); + const offset_fse_buffer = try allocator.alloc(types.compressed_block.Table.Fse, types.compressed_block.table_size_max.offset); + errdefer allocator.free(offset_fse_buffer); + + const decode_state = decompress.DecodeState.init(literal_fse_buffer, match_fse_buffer, offset_fse_buffer); + const buffer = try RingBuffer.init(allocator, frame_context.window_size); + + const literals_data = try allocator.alloc(u8, window_size_max); + errdefer allocator.free(literals_data); + const sequence_data = try allocator.alloc(u8, window_size_max); + errdefer allocator.free(sequence_data); + + return Self{ + .allocator = allocator, + .in_reader = source, + .decode_state = decode_state, + .frame_context = frame_context, + .buffer = buffer, + .checksum = if (verify_checksum) null else {}, + .last_block = false, + .literal_fse_buffer = literal_fse_buffer, + .match_fse_buffer = match_fse_buffer, + .offset_fse_buffer = offset_fse_buffer, + .literals_buffer = literals_data, + .sequence_buffer = sequence_data, + }; + }, + } + } + + pub fn deinit(self: *Self) void { + self.allocator.free(self.decode_state.literal_fse_buffer); + self.allocator.free(self.decode_state.match_fse_buffer); + self.allocator.free(self.decode_state.offset_fse_buffer); + self.allocator.free(self.literals_buffer); + self.allocator.free(self.sequence_buffer); + self.buffer.deinit(self.allocator); + } + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + pub fn read(self: *Self, buffer: []u8) Error!usize { + if (buffer.len == 0) return 0; + + if (self.buffer.isEmpty() and !self.last_block) { + const header_bytes = try self.in_reader.readBytesNoEof(3); + const block_header = decompress.decodeBlockHeader(&header_bytes); + + decompress.decodeBlockReader( + &self.buffer, + self.in_reader, + block_header, + &self.decode_state, + self.frame_context.block_size_max, + self.literals_buffer, + self.sequence_buffer, + ) catch + return error.MalformedBlock; + + self.last_block = block_header.last_block; + if (self.frame_context.hasher_opt) |*hasher| { + const written_slice = self.buffer.sliceLast(self.buffer.len()); + hasher.update(written_slice.first); + hasher.update(written_slice.second); + } + if (block_header.last_block and self.frame_context.has_checksum) { + const checksum = self.in_reader.readIntLittle(u32) catch return error.MalformedFrame; + if (verify_checksum) self.checksum = checksum; + } + } + + const decoded_data_len = self.buffer.len(); + var written_count: usize = 0; + while (written_count < decoded_data_len and written_count < buffer.len) : (written_count += 1) { + buffer[written_count] = self.buffer.read().?; + } + return written_count; + } + + pub fn verifyChecksum(self: *Self) !bool { + if (verify_checksum) { + if (self.checksum) |checksum| { + if (self.frame_context.hasher_opt) |*hasher| { + return checksum == decompress.computeChecksum(hasher); + } + } + } + return true; + } + }; +} + +pub fn zstandardStream(allocator: Allocator, reader: anytype) !ZstandardStream(@TypeOf(reader), true, 8 * (1 << 20)) { + return ZstandardStream(@TypeOf(reader), true, 8 * (1 << 20)).init(allocator, reader); +} + +fn testDecompress(data: []const u8) ![]u8 { + var in_stream = std.io.fixedBufferStream(data); + var stream = try zstandardStream(std.testing.allocator, in_stream.reader()); + defer stream.deinit(); + const result = stream.reader().readAllAlloc(std.testing.allocator, std.math.maxInt(usize)); + try std.testing.expect(try stream.verifyChecksum()); + return result; +} + +fn testReader(data: []const u8, comptime expected: []const u8) !void { + const buf = try testDecompress(data); + defer std.testing.allocator.free(buf); + try std.testing.expectEqualSlices(u8, expected, buf); +} + test "decompression" { const uncompressed = @embedFile("testdata/rfc8478.txt"); const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3"); @@ -20,4 +170,7 @@ test "decompression" { try std.testing.expectEqual(compressed19.len, res19.read_count); try std.testing.expectEqual(uncompressed.len, res19.write_count); try std.testing.expectEqualSlices(u8, uncompressed, buffer); + + try testReader(compressed3, uncompressed); + try testReader(compressed19, uncompressed); }