std.compress.lzma2: tests passing

This commit is contained in:
Andrew Kelley 2025-08-25 20:24:19 -07:00
parent 3cb9baaf65
commit a8ae6c2f42
2 changed files with 131 additions and 151 deletions

View File

@ -105,7 +105,6 @@ pub const RangeDecoder = struct {
pub const Decode = struct { pub const Decode = struct {
properties: Properties, properties: Properties,
unpacked_size: ?u64,
literal_probs: Vec2d, literal_probs: Vec2d,
pos_slot_decoder: [4]BitTree(6), pos_slot_decoder: [4]BitTree(6),
align_decoder: BitTree(4), align_decoder: BitTree(4),
@ -121,15 +120,10 @@ pub const Decode = struct {
len_decoder: LenDecoder, len_decoder: LenDecoder,
rep_len_decoder: LenDecoder, rep_len_decoder: LenDecoder,
pub fn init( pub fn init(gpa: Allocator, properties: Properties) !Decode {
gpa: Allocator,
properties: Properties,
unpacked_size: ?u64,
) !Decode {
return .{ return .{
.properties = properties, .properties = properties,
.unpacked_size = unpacked_size, .literal_probs = try Vec2d.init(gpa, 0x400, @as(usize, 1) << (properties.lc + properties.lp), 0x300),
.literal_probs = try Vec2d.init(gpa, 0x400, .{ @as(usize, 1) << (properties.lc + properties.lp), 0x300 }),
.pos_slot_decoder = @splat(.{}), .pos_slot_decoder = @splat(.{}),
.align_decoder = .{}, .align_decoder = .{},
.pos_decoders = @splat(0x400), .pos_decoders = @splat(0x400),
@ -157,7 +151,7 @@ pub const Decode = struct {
self.literal_probs.fill(0x400); self.literal_probs.fill(0x400);
} else { } else {
self.literal_probs.deinit(gpa); self.literal_probs.deinit(gpa);
self.literal_probs = try Vec2d.init(gpa, 0x400, .{ @as(usize, 1) << (new_props.lc + new_props.lp), 0x300 }); self.literal_probs = try Vec2d.init(gpa, 0x400, @as(usize, 1) << (new_props.lc + new_props.lp), 0x300);
} }
self.properties = new_props; self.properties = new_props;
@ -176,11 +170,12 @@ pub const Decode = struct {
self.rep_len_decoder.reset(); self.rep_len_decoder.reset();
} }
fn processNext( pub fn process(
self: *Decode, self: *Decode,
reader: *Reader, reader: *Reader,
allocating: *Writer.Allocating, allocating: *Writer.Allocating,
buffer: *CircularBuffer, /// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder, decoder: *RangeDecoder,
) !ProcessingStatus { ) !ProcessingStatus {
const gpa = allocating.allocator; const gpa = allocating.allocator;
@ -256,39 +251,11 @@ pub const Decode = struct {
return .more; return .more;
} }
pub fn process(
self: *Decode,
reader: *Reader,
allocating: *Writer.Allocating,
buffer: *CircularBuffer,
decoder: *RangeDecoder,
) !void {
process_next: {
if (self.unpacked_size) |unpacked_size| {
if (buffer.len >= unpacked_size) {
break :process_next;
}
} else if (decoder.isFinished()) {
break :process_next;
}
switch (try self.processNext(reader, allocating, buffer, decoder)) {
.more => return,
.finished => {},
}
}
if (self.unpacked_size) |unpacked_size| {
if (buffer.len != unpacked_size) return error.DecompressedSizeMismatch;
}
try buffer.finish(&allocating.writer);
self.state = math.maxInt(usize);
}
fn decodeLiteral( fn decodeLiteral(
self: *Decode, self: *Decode,
reader: *Reader, reader: *Reader,
buffer: *CircularBuffer, /// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder, decoder: *RangeDecoder,
) !u8 { ) !u8 {
const def_prev_byte = 0; const def_prev_byte = 0;
@ -377,10 +344,7 @@ pub const Decode = struct {
} }
pub fn get(self: CircularBuffer, index: usize) u8 { pub fn get(self: CircularBuffer, index: usize) u8 {
return if (0 <= index and index < self.buf.items.len) return if (0 <= index and index < self.buf.items.len) self.buf.items[index] else 0;
self.buf.items[index]
else
0;
} }
pub fn set(self: *CircularBuffer, gpa: Allocator, index: usize, value: u8) !void { pub fn set(self: *CircularBuffer, gpa: Allocator, index: usize, value: u8) !void {
@ -524,29 +488,29 @@ pub const Decode = struct {
data: []u16, data: []u16,
cols: usize, cols: usize,
pub fn init(gpa: Allocator, value: u16, size: struct { usize, usize }) !Vec2d { pub fn init(gpa: Allocator, value: u16, w: usize, h: usize) !Vec2d {
const len = try math.mul(usize, size[0], size[1]); const len = try math.mul(usize, w, h);
const data = try gpa.alloc(u16, len); const data = try gpa.alloc(u16, len);
@memset(data, value); @memset(data, value);
return .{ return .{
.data = data, .data = data,
.cols = size[1], .cols = h,
}; };
} }
pub fn deinit(self: *Vec2d, gpa: Allocator) void { pub fn deinit(v: *Vec2d, gpa: Allocator) void {
gpa.free(self.data); gpa.free(v.data);
self.* = undefined; v.* = undefined;
} }
pub fn fill(self: *Vec2d, value: u16) void { pub fn fill(v: *Vec2d, value: u16) void {
@memset(self.data, value); @memset(v.data, value);
} }
fn get(self: Vec2d, row: usize) ![]u16 { fn get(v: Vec2d, row: usize) ![]u16 {
const start_row = try math.mul(usize, row, self.cols); const start_row = try math.mul(usize, row, v.cols);
const end_row = try math.add(usize, start_row, self.cols); const end_row = try math.add(usize, start_row, v.cols);
return self.data[start_row..end_row]; return v.data[start_row..end_row];
} }
}; };
@ -627,6 +591,7 @@ pub const Decompress = struct {
range_decoder: RangeDecoder, range_decoder: RangeDecoder,
decode: Decode, decode: Decode,
err: ?Error, err: ?Error,
unpacked_size: ?u64,
pub const Error = error{ pub const Error = error{
OutOfMemory, OutOfMemory,
@ -654,7 +619,7 @@ pub const Decompress = struct {
.input = input, .input = input,
.buffer = Decode.CircularBuffer.init(params.dict_size, mem_limit), .buffer = Decode.CircularBuffer.init(params.dict_size, mem_limit),
.range_decoder = try RangeDecoder.init(input), .range_decoder = try RangeDecoder.init(input),
.decode = try Decode.init(gpa, params.properties, params.unpacked_size), .decode = try Decode.init(gpa, params.properties),
.reader = .{ .reader = .{
.buffer = buffer, .buffer = buffer,
.vtable = &.{ .vtable = &.{
@ -666,6 +631,7 @@ pub const Decompress = struct {
.end = 0, .end = 0,
}, },
.err = null, .err = null,
.unpacked_size = params.unpacked_size,
}; };
} }
@ -728,20 +694,46 @@ pub const Decompress = struct {
r.end = allocating.writer.end; r.end = allocating.writer.end;
} }
if (d.decode.state == math.maxInt(usize)) return error.EndOfStream; if (d.decode.state == math.maxInt(usize)) return error.EndOfStream;
d.decode.process(d.input, &allocating, &d.buffer, &d.range_decoder) catch |err| switch (err) {
process_next: {
if (d.unpacked_size) |unpacked_size| {
if (d.buffer.len >= unpacked_size) break :process_next;
} else if (d.range_decoder.isFinished()) {
break :process_next;
}
switch (d.decode.process(d.input, &allocating, &d.buffer, &d.range_decoder) catch |err| switch (err) {
error.WriteFailed => {
d.err = error.OutOfMemory;
return error.ReadFailed;
},
error.EndOfStream => {
d.err = error.EndOfStream;
return error.ReadFailed;
},
else => |e| {
d.err = e;
return error.ReadFailed;
},
}) {
.more => return 0,
.finished => break :process_next,
}
}
if (d.unpacked_size) |unpacked_size| {
if (d.buffer.len != unpacked_size) {
d.err = error.DecompressedSizeMismatch;
return error.ReadFailed;
}
}
d.buffer.finish(&allocating.writer) catch |err| switch (err) {
error.WriteFailed => { error.WriteFailed => {
d.err = error.OutOfMemory; d.err = error.OutOfMemory;
return error.ReadFailed; return error.ReadFailed;
}, },
error.EndOfStream => {
d.err = error.EndOfStream;
return error.ReadFailed;
},
else => |e| {
d.err = e;
return error.ReadFailed;
},
}; };
d.decode.state = math.maxInt(usize);
return 0; return 0;
} }
}; };

View File

@ -6,17 +6,15 @@ const Writer = std.Io.Writer;
const Reader = std.Io.Reader; const Reader = std.Io.Reader;
/// An accumulating buffer for LZ sequences /// An accumulating buffer for LZ sequences
pub const LzAccumBuffer = struct { pub const AccumBuffer = struct {
/// Buffer /// Buffer
buf: ArrayList(u8), buf: ArrayList(u8),
/// Buffer memory limit /// Buffer memory limit
memlimit: usize, memlimit: usize,
/// Total number of bytes sent through the buffer /// Total number of bytes sent through the buffer
len: usize, len: usize,
pub fn init(memlimit: usize) LzAccumBuffer { pub fn init(memlimit: usize) AccumBuffer {
return .{ return .{
.buf = .{}, .buf = .{},
.memlimit = memlimit, .memlimit = memlimit,
@ -24,20 +22,20 @@ pub const LzAccumBuffer = struct {
}; };
} }
pub fn appendByte(self: *LzAccumBuffer, allocator: Allocator, byte: u8) !void { pub fn appendByte(self: *AccumBuffer, allocator: Allocator, byte: u8) !void {
try self.buf.append(allocator, byte); try self.buf.append(allocator, byte);
self.len += 1; self.len += 1;
} }
/// Reset the internal dictionary /// Reset the internal dictionary
pub fn reset(self: *LzAccumBuffer, writer: *Writer) !void { pub fn reset(self: *AccumBuffer, writer: *Writer) !void {
try writer.writeAll(self.buf.items); try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity(); self.buf.clearRetainingCapacity();
self.len = 0; self.len = 0;
} }
/// Retrieve the last byte or return a default /// Retrieve the last byte or return a default
pub fn lastOr(self: LzAccumBuffer, lit: u8) u8 { pub fn lastOr(self: AccumBuffer, lit: u8) u8 {
const buf_len = self.buf.items.len; const buf_len = self.buf.items.len;
return if (buf_len == 0) return if (buf_len == 0)
lit lit
@ -46,7 +44,7 @@ pub const LzAccumBuffer = struct {
} }
/// Retrieve the n-th last byte /// Retrieve the n-th last byte
pub fn lastN(self: LzAccumBuffer, dist: usize) !u8 { pub fn lastN(self: AccumBuffer, dist: usize) !u8 {
const buf_len = self.buf.items.len; const buf_len = self.buf.items.len;
if (dist > buf_len) { if (dist > buf_len) {
return error.CorruptInput; return error.CorruptInput;
@ -57,7 +55,7 @@ pub const LzAccumBuffer = struct {
/// Append a literal /// Append a literal
pub fn appendLiteral( pub fn appendLiteral(
self: *LzAccumBuffer, self: *AccumBuffer,
allocator: Allocator, allocator: Allocator,
lit: u8, lit: u8,
writer: *Writer, writer: *Writer,
@ -72,7 +70,7 @@ pub const LzAccumBuffer = struct {
/// Fetch an LZ sequence (length, distance) from inside the buffer /// Fetch an LZ sequence (length, distance) from inside the buffer
pub fn appendLz( pub fn appendLz(
self: *LzAccumBuffer, self: *AccumBuffer,
allocator: Allocator, allocator: Allocator,
len: usize, len: usize,
dist: usize, dist: usize,
@ -95,12 +93,12 @@ pub const LzAccumBuffer = struct {
self.len += len; self.len += len;
} }
pub fn finish(self: *LzAccumBuffer, writer: *Writer) !void { pub fn finish(self: *AccumBuffer, writer: *Writer) !void {
try writer.writeAll(self.buf.items); try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity(); self.buf.clearRetainingCapacity();
} }
pub fn deinit(self: *LzAccumBuffer, allocator: Allocator) void { pub fn deinit(self: *AccumBuffer, allocator: Allocator) void {
self.buf.deinit(allocator); self.buf.deinit(allocator);
self.* = undefined; self.* = undefined;
} }
@ -109,59 +107,43 @@ pub const LzAccumBuffer = struct {
pub const Decode = struct { pub const Decode = struct {
lzma_decode: lzma.Decode, lzma_decode: lzma.Decode,
pub fn init(allocator: Allocator) !Decode { pub fn init(gpa: Allocator) !Decode {
return Decode{ return .{ .lzma_decode = try lzma.Decode.init(gpa, .{ .lc = 0, .lp = 0, .pb = 0 }) };
.lzma_decode = try lzma.Decode.init(
allocator,
.{
.lc = 0,
.lp = 0,
.pb = 0,
},
null,
),
};
} }
pub fn deinit(self: *Decode, allocator: Allocator) void { pub fn deinit(self: *Decode, gpa: Allocator) void {
self.lzma_decode.deinit(allocator); self.lzma_decode.deinit(gpa);
self.* = undefined; self.* = undefined;
} }
pub fn decompress( pub fn decompress(d: *Decode, reader: *Reader, allocating: *Writer.Allocating) !void {
self: *Decode, const gpa = allocating.allocator;
allocator: Allocator,
reader: *Reader, var accum = AccumBuffer.init(std.math.maxInt(usize));
writer: *Writer, defer accum.deinit(gpa);
) !void {
var accum = LzAccumBuffer.init(std.math.maxInt(usize));
defer accum.deinit(allocator);
while (true) { while (true) {
const status = try reader.readByte(); const status = try reader.takeByte();
switch (status) { switch (status) {
0 => break, 0 => break,
1 => try parseUncompressed(allocator, reader, writer, &accum, true), 1 => try parseUncompressed(reader, allocating, &accum, true),
2 => try parseUncompressed(allocator, reader, writer, &accum, false), 2 => try parseUncompressed(reader, allocating, &accum, false),
else => try self.parseLzma(allocator, reader, writer, &accum, status), else => try d.parseLzma(reader, allocating, &accum, status),
} }
} }
try accum.finish(writer); try accum.finish(&allocating.writer);
} }
fn parseLzma( fn parseLzma(
self: *Decode, d: *Decode,
allocator: Allocator,
reader: *Reader, reader: *Reader,
writer: *Writer, allocating: *Writer.Allocating,
accum: *LzAccumBuffer, accum: *AccumBuffer,
status: u8, status: u8,
) !void { ) !void {
if (status & 0x80 == 0) { if (status & 0x80 == 0) return error.CorruptInput;
return error.CorruptInput;
}
const Reset = struct { const Reset = struct {
dict: bool, dict: bool,
@ -169,23 +151,23 @@ pub const Decode = struct {
props: bool, props: bool,
}; };
const reset = switch ((status >> 5) & 0x3) { const reset: Reset = switch ((status >> 5) & 0x3) {
0 => Reset{ 0 => .{
.dict = false, .dict = false,
.state = false, .state = false,
.props = false, .props = false,
}, },
1 => Reset{ 1 => .{
.dict = false, .dict = false,
.state = true, .state = true,
.props = false, .props = false,
}, },
2 => Reset{ 2 => .{
.dict = false, .dict = false,
.state = true, .state = true,
.props = true, .props = true,
}, },
3 => Reset{ 3 => .{
.dict = true, .dict = true,
.state = true, .state = true,
.props = true, .props = true,
@ -196,24 +178,24 @@ pub const Decode = struct {
const unpacked_size = blk: { const unpacked_size = blk: {
var tmp: u64 = status & 0x1F; var tmp: u64 = status & 0x1F;
tmp <<= 16; tmp <<= 16;
tmp |= try reader.readInt(u16, .big); tmp |= try reader.takeInt(u16, .big);
break :blk tmp + 1; break :blk tmp + 1;
}; };
const packed_size = blk: { const packed_size = blk: {
const tmp: u17 = try reader.readInt(u16, .big); const tmp: u17 = try reader.takeInt(u16, .big);
break :blk tmp + 1; break :blk tmp + 1;
}; };
if (reset.dict) { if (reset.dict) try accum.reset(&allocating.writer);
try accum.reset(writer);
} const ld = &d.lzma_decode;
if (reset.state) { if (reset.state) {
var new_props = self.lzma_decode.properties; var new_props = ld.properties;
if (reset.props) { if (reset.props) {
var props = try reader.readByte(); var props = try reader.takeByte();
if (props >= 225) { if (props >= 225) {
return error.CorruptInput; return error.CorruptInput;
} }
@ -231,38 +213,44 @@ pub const Decode = struct {
new_props = .{ .lc = lc, .lp = lp, .pb = pb }; new_props = .{ .lc = lc, .lp = lp, .pb = pb };
} }
try self.lzma_decode.resetState(allocator, new_props); try ld.resetState(allocating.allocator, new_props);
} }
self.lzma_decode.unpacked_size = unpacked_size + accum.len; var range_decoder = try lzma.RangeDecoder.init(reader);
var counter = std.io.countingReader(reader); while (true) {
const counter_reader = counter.reader(); if (accum.len >= unpacked_size) break;
if (range_decoder.isFinished()) break;
var rangecoder = try lzma.RangeDecoder.init(counter_reader); switch (try ld.process(reader, allocating, accum, &range_decoder)) {
while (try self.lzma_decode.process(allocator, counter_reader, writer, accum, &rangecoder) == .continue_) {} .more => continue,
.finished => break,
if (counter.bytes_read != packed_size) { }
return error.CorruptInput;
} }
if (accum.len != unpacked_size) return error.DecompressedSizeMismatch;
// TODO restore this error
//if (counter.bytes_read != packed_size) {
// return error.CorruptInput;
//}
_ = packed_size;
} }
fn parseUncompressed( fn parseUncompressed(
allocator: Allocator,
reader: *Reader, reader: *Reader,
writer: *Writer, allocating: *Writer.Allocating,
accum: *LzAccumBuffer, accum: *AccumBuffer,
reset_dict: bool, reset_dict: bool,
) !void { ) !void {
const unpacked_size = @as(u17, try reader.readInt(u16, .big)) + 1; const unpacked_size = @as(u17, try reader.takeInt(u16, .big)) + 1;
if (reset_dict) { if (reset_dict) try accum.reset(&allocating.writer);
try accum.reset(writer);
}
var i: @TypeOf(unpacked_size) = 0; const gpa = allocating.allocator;
while (i < unpacked_size) : (i += 1) {
try accum.appendByte(allocator, try reader.readByte()); var i = unpacked_size;
while (i != 0) {
try accum.appendByte(gpa, try reader.takeByte());
i -= 1;
} }
} }
}; };
@ -273,13 +261,13 @@ test "decompress hello world stream" {
const gpa = std.testing.allocator; const gpa = std.testing.allocator;
var stream: std.Io.Reader = .fixed(compressed); var decode = try Decode.init(gpa);
var decode = try Decode.init(gpa, &stream);
defer decode.deinit(gpa); defer decode.deinit(gpa);
const result = try decode.reader.allocRemaining(gpa, .unlimited); var stream: std.Io.Reader = .fixed(compressed);
defer gpa.free(result); var result: std.Io.Writer.Allocating = .init(gpa);
defer result.deinit();
try std.testing.expectEqualStrings(expected, result); try decode.decompress(&stream, &result);
try std.testing.expectEqualStrings(expected, result.written());
} }