diff --git a/lib/std/compress/lzma.zig b/lib/std/compress/lzma.zig index 6e7c29684b..8bb8c19da1 100644 --- a/lib/std/compress/lzma.zig +++ b/lib/std/compress/lzma.zig @@ -1,4 +1,6 @@ const std = @import("../std.zig"); +const math = std.math; +const mem = std.mem; const Allocator = std.mem.Allocator; pub const decode = @import("lzma/decode.zig"); @@ -6,13 +8,80 @@ pub const decode = @import("lzma/decode.zig"); pub fn decompress( allocator: Allocator, reader: anytype, - writer: anytype, +) !Decompress(@TypeOf(reader)) { + return decompressWithOptions(allocator, reader, .{}); +} + +pub fn decompressWithOptions( + allocator: Allocator, + reader: anytype, options: decode.Options, -) !void { +) !Decompress(@TypeOf(reader)) { const params = try decode.Params.readHeader(reader, options); - var decoder = try decode.Decoder.init(allocator, params, options.memlimit); - defer decoder.deinit(allocator); - return decoder.decompress(allocator, reader, writer); + return Decompress(@TypeOf(reader)).init(allocator, reader, params, options.memlimit); +} + +pub fn Decompress(comptime ReaderType: type) type { + return struct { + const Self = @This(); + + pub const Error = + ReaderType.Error || + Allocator.Error || + error{ CorruptInput, EndOfStream, Overflow }; + + pub const Reader = std.io.Reader(*Self, Error, read); + + allocator: Allocator, + in_reader: ReaderType, + to_read: std.ArrayListUnmanaged(u8), + + buffer: decode.lzbuffer.LzCircularBuffer, + decoder: decode.rangecoder.RangeDecoder, + state: decode.DecoderState, + + pub fn init(allocator: Allocator, source: ReaderType, params: decode.Params, memlimit: ?usize) !Self { + return Self{ + .allocator = allocator, + .in_reader = source, + .to_read = .{}, + + .buffer = decode.lzbuffer.LzCircularBuffer.init(params.dict_size, memlimit orelse math.maxInt(usize)), + .decoder = try decode.rangecoder.RangeDecoder.init(source), + .state = try decode.DecoderState.init(allocator, params.properties, params.unpacked_size), + }; + } + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + pub fn deinit(self: *Self) void { + self.to_read.deinit(self.allocator); + self.buffer.deinit(self.allocator); + self.state.deinit(self.allocator); + self.* = undefined; + } + + pub fn read(self: *Self, output: []u8) Error!usize { + const writer = self.to_read.writer(self.allocator); + while (self.to_read.items.len < output.len) { + switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { + .continue_ => {}, + .finished => { + try self.buffer.finish(writer); + break; + }, + } + } + const input = self.to_read.items; + const n = math.min(input.len, output.len); + mem.copy(u8, output[0..n], input[0..n]); + mem.copy(u8, input, input[n..]); + self.to_read.shrinkRetainingCapacity(input.len - n); + return n; + } + }; } test { diff --git a/lib/std/compress/lzma/decode.zig b/lib/std/compress/lzma/decode.zig index 31a676b40a..6c9a3ae862 100644 --- a/lib/std/compress/lzma/decode.zig +++ b/lib/std/compress/lzma/decode.zig @@ -280,26 +280,29 @@ pub const DecoderState = struct { writer: anytype, buffer: anytype, decoder: *RangeDecoder, - ) !void { - while (true) { + ) !ProcessingStatus { + process_next: { if (self.unpacked_size) |unpacked_size| { if (buffer.len >= unpacked_size) { - break; + break :process_next; } } else if (decoder.isFinished()) { - break; + break :process_next; } - if (try self.processNext(allocator, reader, writer, buffer, decoder) == .finished) { - break; + switch (try self.processNext(allocator, reader, writer, buffer, decoder)) { + .continue_ => return .continue_, + .finished => break :process_next, } } - if (self.unpacked_size) |len| { - if (len != buffer.len) { + if (self.unpacked_size) |unpacked_size| { + if (buffer.len != unpacked_size) { return error.CorruptInput; } } + + return .finished; } fn decodeLiteral( @@ -374,36 +377,3 @@ pub const DecoderState = struct { return result; } }; - -pub const Decoder = struct { - params: Params, - memlimit: usize, - state: DecoderState, - - pub fn init(allocator: Allocator, params: Params, memlimit: ?usize) !Decoder { - return Decoder{ - .params = params, - .memlimit = memlimit orelse math.maxInt(usize), - .state = try DecoderState.init(allocator, params.properties, params.unpacked_size), - }; - } - - pub fn deinit(self: *Decoder, allocator: Allocator) void { - self.state.deinit(allocator); - self.* = undefined; - } - - pub fn decompress( - self: *Decoder, - allocator: Allocator, - reader: anytype, - writer: anytype, - ) !void { - var buffer = LzCircularBuffer.init(self.params.dict_size, self.memlimit); - defer buffer.deinit(allocator); - - var decoder = try RangeDecoder.init(reader); - try self.state.process(allocator, reader, writer, &buffer, &decoder); - try buffer.finish(writer); - } -}; diff --git a/lib/std/compress/lzma/decode/lzbuffer.zig b/lib/std/compress/lzma/decode/lzbuffer.zig index 4063c7993a..80c470c5f9 100644 --- a/lib/std/compress/lzma/decode/lzbuffer.zig +++ b/lib/std/compress/lzma/decode/lzbuffer.zig @@ -98,6 +98,7 @@ pub const LzAccumBuffer = struct { pub fn finish(self: *Self, writer: anytype) !void { try writer.writeAll(self.buf.items); + self.buf.clearRetainingCapacity(); } pub fn deinit(self: *Self, allocator: Allocator) void { @@ -216,6 +217,7 @@ pub const LzCircularBuffer = struct { pub fn finish(self: *Self, writer: anytype) !void { if (self.cursor > 0) { try writer.writeAll(self.buf.items[0..self.cursor]); + self.cursor = 0; } } diff --git a/lib/std/compress/lzma/test.zig b/lib/std/compress/lzma/test.zig index e720d87222..bdfe2909d8 100644 --- a/lib/std/compress/lzma/test.zig +++ b/lib/std/compress/lzma/test.zig @@ -1,22 +1,24 @@ const std = @import("../../std.zig"); const lzma = @import("../lzma.zig"); -fn testDecompress(compressed: []const u8, writer: anytype) !void { +fn testDecompress(compressed: []const u8) ![]u8 { const allocator = std.testing.allocator; var stream = std.io.fixedBufferStream(compressed); - try lzma.decompress(allocator, stream.reader(), writer, .{}); + var decompressor = try lzma.decompress(allocator, stream.reader()); + defer decompressor.deinit(); + const reader = decompressor.reader(); + return reader.readAllAlloc(allocator, std.math.maxInt(usize)); } fn testDecompressEqual(expected: []const u8, compressed: []const u8) !void { const allocator = std.testing.allocator; - var decomp = std.ArrayList(u8).init(allocator); - defer decomp.deinit(); - try testDecompress(compressed, decomp.writer()); - try std.testing.expectEqualSlices(u8, expected, decomp.items); + const decomp = try testDecompress(compressed); + defer allocator.free(decomp); + try std.testing.expectEqualSlices(u8, expected, decomp); } fn testDecompressError(expected: anyerror, compressed: []const u8) !void { - return std.testing.expectError(expected, testDecompress(compressed, std.io.null_writer)); + return std.testing.expectError(expected, testDecompress(compressed)); } test "LZMA: decompress empty world" { diff --git a/lib/std/compress/lzma2/decode.zig b/lib/std/compress/lzma2/decode.zig index 911f0352f2..7297a1a51b 100644 --- a/lib/std/compress/lzma2/decode.zig +++ b/lib/std/compress/lzma2/decode.zig @@ -141,7 +141,7 @@ pub const Decoder = struct { const counter_reader = counter.reader(); var rangecoder = try RangeDecoder.init(counter_reader); - try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder); + while (try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder) == .continue_) {} if (counter.bytes_read != packed_size) { return error.CorruptInput;