zig/lib/std/compress/flate/huffman_decoder.zig
Igor Anić d645114f7e add deflate implemented from first principles
Zig deflate compression/decompression implementation. It supports compression and decompression of the gzip, zlib, and raw deflate formats.

Fixes #18062.

This PR replaces the current compress/gzip and compress/zlib packages. The deflate package is renamed to flate. Flate is a common name for deflate/inflate, where deflate is compression and inflate is decompression.

There are breaking changes. Method signatures changed because the allocator was removed, and I also unified the API for all three namespaces (flate, gzip, zlib).

Currently I put the old packages under the v1 namespace; they are still available as compress/v1/gzip, compress/v1/zlib, and compress/v1/deflate. The idea is to give users of the current API a little time before they have to analyze what needs to change. That raises the question of when it is safe to remove the v1 namespace.

Here is the current API in the compress package:

```Zig
// deflate
    fn compressor(allocator, writer, options) !Compressor(@TypeOf(writer))
    fn Compressor(comptime WriterType) type

    fn decompressor(allocator, reader, null) !Decompressor(@TypeOf(reader))
    fn Decompressor(comptime ReaderType: type) type

// gzip
    fn compress(allocator, writer, options) !Compress(@TypeOf(writer))
    fn Compress(comptime WriterType: type) type

    fn decompress(allocator, reader) !Decompress(@TypeOf(reader))
    fn Decompress(comptime ReaderType: type) type

// zlib
    fn compressStream(allocator, writer, options) !CompressStream(@TypeOf(writer))
    fn CompressStream(comptime WriterType: type) type

    fn decompressStream(allocator, reader) !DecompressStream(@TypeOf(reader))
    fn DecompressStream(comptime ReaderType: type) type

// xz
   fn decompress(allocator: Allocator, reader: anytype) !Decompress(@TypeOf(reader))
   fn Decompress(comptime ReaderType: type) type

// lzma
    fn decompress(allocator, reader) !Decompress(@TypeOf(reader))
    fn Decompress(comptime ReaderType: type) type

// lzma2
    fn decompress(allocator, reader, writer) !void

// zstandard:
    fn DecompressStream(ReaderType, options) type
    fn decompressStream(allocator, reader) DecompressStream(@TypeOf(reader), .{})
    struct decompress
```

The proposed naming convention:
 - Compressor/Decompressor for functions which return a type, like Reader/Writer/GeneralPurposeAllocator
 - compressor/decompressor for functions which are initializers for that type, like reader/writer/allocator
 - compress/decompress for one-shot operations that accept a reader/writer pair, like read/write/alloc

```Zig
/// Compress from reader and write compressed data to the writer.
fn compress(reader: anytype, writer: anytype, options: Options) !void

/// Create Compressor which outputs the writer.
fn compressor(writer: anytype, options: Options) !Compressor(@TypeOf(writer))

/// Compressor type
fn Compressor(comptime WriterType: type) type

/// Decompress from reader and write plain data to the writer.
fn decompress(reader: anytype, writer: anytype) !void

/// Create Decompressor which reads from reader.
fn decompressor(reader: anytype) Decompressor(@TypeOf(reader))

/// Decompressor type
fn Decompressor(comptime ReaderType: type) type

```

Comparing this implementation with the one we currently have in Zig's standard library (std):
std is roughly 1.2-1.4 times slower in decompression and 1.1-1.2 times slower in compression. Compressed sizes are pretty much the same in both cases.
More results are in [this](https://github.com/ianic/flate) repo.

This library uses static allocations for all structures and doesn't require an allocator. That makes sense especially for deflate, where all structures and internal buffers are allocated to their full size anyway. It makes a little less sense for inflate, where the std version uses less memory by not preallocating arrays to their theoretical maximum size, since those arrays are usually not fully used.

For deflate, this library allocates 395K while std allocates 779K.
For inflate, this library allocates 74.5K while std allocates around 36K.

The inflate difference is because here we use a 64K history instead of the 32K used in std.

If this is merged, existing usages of compress gzip/zlib/deflate need some changes. Here is an example with the necessary changes in comments:

```Zig

const std = @import("std");

// To get this file:
// wget -nc -O war_and_peace.txt https://www.gutenberg.org/ebooks/2600.txt.utf-8
const data = @embedFile("war_and_peace.txt");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer std.debug.assert(gpa.deinit() == .ok);
    const allocator = gpa.allocator();

    try oldDeflate(allocator);
    try new(std.compress.flate, allocator);

    try oldZlib(allocator);
    try new(std.compress.zlib, allocator);

    try oldGzip(allocator);
    try new(std.compress.gzip, allocator);
}

pub fn new(comptime pkg: type, allocator: std.mem.Allocator) !void {
    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();

    // Compressor
    var cmp = try pkg.compressor(buf.writer(), .{});
    _ = try cmp.write(data);
    try cmp.finish();

    var fbs = std.io.fixedBufferStream(buf.items);
    // Decompressor
    var dcp = pkg.decompressor(fbs.reader());

    const plain = try dcp.reader().readAllAlloc(allocator, std.math.maxInt(usize));
    defer allocator.free(plain);
    try std.testing.expectEqualSlices(u8, data, plain);
}

pub fn oldDeflate(allocator: std.mem.Allocator) !void {
    const deflate = std.compress.v1.deflate;

    // Compressor
    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();
    // Remove allocator
    // Rename deflate -> flate
    var cmp = try deflate.compressor(allocator, buf.writer(), .{});
    _ = try cmp.write(data);
    try cmp.close(); // Rename to finish
    cmp.deinit(); // Remove

    // Decompressor
    var fbs = std.io.fixedBufferStream(buf.items);
    // Remove allocator and last param
    // Rename deflate -> flate
    // Remove try
    var dcp = try deflate.decompressor(allocator, fbs.reader(), null);
    defer dcp.deinit(); // Remove

    const plain = try dcp.reader().readAllAlloc(allocator, std.math.maxInt(usize));
    defer allocator.free(plain);
    try std.testing.expectEqualSlices(u8, data, plain);
}

pub fn oldZlib(allocator: std.mem.Allocator) !void {
    const zlib = std.compress.v1.zlib;

    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();

    // Compressor
    // Rename compressStream => compressor
    // Remove allocator
    var cmp = try zlib.compressStream(allocator, buf.writer(), .{});
    _ = try cmp.write(data);
    try cmp.finish();
    cmp.deinit(); // Remove

    var fbs = std.io.fixedBufferStream(buf.items);
    // Decompressor
    // decompressStream => decompressor
    // Remove allocator
    // Remove try
    var dcp = try zlib.decompressStream(allocator, fbs.reader());
    defer dcp.deinit(); // Remove

    const plain = try dcp.reader().readAllAlloc(allocator, std.math.maxInt(usize));
    defer allocator.free(plain);
    try std.testing.expectEqualSlices(u8, data, plain);
}

pub fn oldGzip(allocator: std.mem.Allocator) !void {
    const gzip = std.compress.v1.gzip;

    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();

    // Compressor
    // Rename compress => compressor
    // Remove allocator
    var cmp = try gzip.compress(allocator, buf.writer(), .{});
    _ = try cmp.write(data);
    try cmp.close(); // Rename to finish
    cmp.deinit(); // Remove

    var fbs = std.io.fixedBufferStream(buf.items);
    // Decompressor
    // Rename decompress => decompressor
    // Remove allocator
    // Remove try
    var dcp = try gzip.decompress(allocator, fbs.reader());
    defer dcp.deinit(); // Remove

    const plain = try dcp.reader().readAllAlloc(allocator, std.math.maxInt(usize));
    defer allocator.free(plain);
    try std.testing.expectEqualSlices(u8, data, plain);
}

```
2024-02-14 18:28:20 +01:00

309 lines
11 KiB
Zig

const std = @import("std");
const testing = std.testing;
pub const Symbol = packed struct {
    /// Category of an alphabet entry. Declaration order is also the
    /// tie-breaking order used by `asc` for equal code lengths.
    pub const Kind = enum(u2) {
        literal,
        end_of_block,
        match,
    };

    /// Symbol value from the alphabet.
    symbol: u8 = 0,
    /// Number of bits in the huffman code (0-15); 0 marks an unused symbol.
    code_bits: u4 = 0,
    /// Which part of the alphabet this symbol belongs to.
    kind: Kind = .literal,
    /// Huffman code assigned to the symbol.
    code: u16 = 0,
    /// Index of the next symbol in a linked list. Index 0 is used as the
    /// null link: after sorting, position 0 holds the shortest code, which
    /// always fits into the lookup table and is never linked.
    next: u16 = 0,

    /// Less-than predicate for sorting: shorter codes first, then by kind,
    /// then by symbol value.
    pub fn asc(_: void, a: Symbol, b: Symbol) bool {
        if (a.code_bits != b.code_bits) return a.code_bits < b.code_bits;
        if (a.kind != b.kind) return @intFromEnum(a.kind) < @intFromEnum(b.kind);
        return a.symbol < b.symbol;
    }
};
// Concrete decoders; arguments are (alphabet_size, max_code_bits, lookup_bits).
/// Literal/length alphabet: 286 symbols (literals, end of block, matches),
/// codes up to 15 bits, 9-bit lookup table.
pub const LiteralDecoder = HuffmanDecoder(286, 15, 9);
/// Distance alphabet: 30 symbols, codes up to 15 bits, 9-bit lookup table.
pub const DistanceDecoder = HuffmanDecoder(30, 15, 9);
/// Code-length (codegen) alphabet: 19 symbols, codes up to 7 bits; the
/// 7-bit lookup table covers every code, so no linked lists are needed.
pub const CodegenDecoder = HuffmanDecoder(19, 7, 7);
/// Errors reported while building a huffman decoder or decoding with it.
pub const Error = error{
    /// Literal/length code lengths give no code to the end-of-block symbol.
    MissingEndOfBlockCode,
    /// Some code length is assigned more codes than it can hold.
    OversubscribedHuffmanTree,
    /// Code lengths leave unused code space (not allowed here).
    IncompleteHuffmanTree,
    /// The input bits do not correspond to any symbol's code.
    InvalidCode,
};
/// Creates huffman tree codes from a list of code lengths (in `generate`).
///
/// `find` then finds the symbol for code bits. A code can be any length
/// between 1 and `max_code_bits` bits. When calling `find` we don't know how
/// many bits will be used to find the symbol. The returned symbol has a
/// `code_bits` field which defines how far to advance in the bit stream.
///
/// Codes are kept left-aligned in `max_code_bits` bits. The top `lookup_bits`
/// bits of a code index the small `lookup` table (1 << lookup_bits entries).
/// A code of at most `lookup_bits` bits is written into every lookup entry
/// it is a prefix of. Longer codes are chained through `Symbol.next` into a
/// linked list rooted at the lookup entry for their `lookup_bits`-bit prefix.
/// It is a variation of the algorithm explained in [zlib](https://github.com/madler/zlib/blob/643e17b7498d12ab8d15565662880579692f769d/doc/algorithm.txt#L92)
/// with the difference that here we use statically allocated arrays.
fn HuffmanDecoder(
    comptime alphabet_size: u16,
    comptime max_code_bits: u4,
    comptime lookup_bits: u4,
) type {
    const lookup_shift = max_code_bits - lookup_bits;

    return struct {
        // all symbols in alphabet, sorted by code_bits, kind, symbol
        symbols: [alphabet_size]Symbol = undefined,
        // lookup table: top lookup_bits of code -> symbol (or linked list root)
        lookup: [1 << lookup_bits]Symbol = undefined,

        const Self = @This();

        /// Generates symbols and lookup tables from the list of code lengths,
        /// one per symbol. Lengths past `lens.len` are treated as 0 (unused).
        /// Returns an error if the lengths do not describe a valid tree.
        pub fn generate(self: *Self, lens: []const u4) !void {
            try checkCompleteness(lens);

            // init alphabet with code_bits
            for (&self.symbols, 0..) |*sym, i| {
                const cb: u4 = if (i < lens.len) lens[i] else 0;
                sym.* = if (i < 256)
                    .{ .kind = .literal, .symbol = @intCast(i), .code_bits = cb }
                else if (i == 256)
                    .{ .kind = .end_of_block, .symbol = 0xff, .code_bits = cb }
                else
                    .{ .kind = .match, .symbol = @intCast(i - 257), .code_bits = cb };
            }
            std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc);

            // reset lookup table
            @memset(&self.lookup, Symbol{});

            // Assign canonical codes in sorted order. Each symbol occupies
            // 1 << (max_code_bits - code_bits) consecutive left-aligned codes.
            // reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639
            var code: u16 = 0;
            var idx: u16 = 0;
            for (&self.symbols, 0..) |*sym, pos| {
                if (sym.code_bits == 0) continue; // skip unused
                sym.code = code;
                const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits));
                const next_idx = next_code >> lookup_shift;
                // defensive: never index past the lookup table
                if (next_idx > self.lookup.len or idx >= self.lookup.len) break;
                if (sym.code_bits <= lookup_bits) {
                    // short code: fill every lookup entry it is a prefix of
                    for (idx..next_idx) |j|
                        self.lookup[j] = sym.*;
                } else {
                    // long code: push onto the linked list rooted at its prefix entry
                    const root = &self.lookup[idx];
                    const root_next = root.next;
                    root.next = @intCast(pos);
                    sym.next = root_next;
                }
                idx = next_idx;
                code = next_code;
            }
        }

        /// Given the list of code lengths, check that it represents a canonical
        /// Huffman code for n symbols. An all-zero list (empty tree) is accepted.
        /// Assumes every length is <= max_code_bits.
        ///
        /// Reference: https://github.com/madler/zlib/blob/5c42a230b7b468dff011f444161c0145b5efae59/contrib/puff/puff.c#L340
        fn checkCompleteness(lens: []const u4) !void {
            // The literal/length tree must give the end-of-block symbol (256) a
            // code. A list too short to contain it is missing it as well.
            if (alphabet_size == 286)
                if (lens.len <= 256 or lens[256] == 0) return error.MissingEndOfBlockCode;

            var count = [_]u16{0} ** (@as(usize, max_code_bits) + 1);
            var max: usize = 0;
            for (lens) |n| {
                if (n == 0) continue;
                if (n > max) max = n;
                count[n] += 1;
            }
            if (max == 0) // empty tree
                return;

            // check for an over-subscribed or incomplete set of lengths
            var left: usize = 1; // one possible code of zero length
            for (1..count.len) |len| {
                left <<= 1; // one more bit, double codes left
                if (count[len] > left)
                    return error.OversubscribedHuffmanTree;
                left -= count[len]; // deduct count from possible codes
            }
            if (left > 0) { // left > 0 means incomplete
                // An incomplete code is accepted only for the literal/length and
                // distance trees (max_code_bits > 7), and only when it is a
                // single code of length 1.
                if (max_code_bits > 7 and max == 1 and count[1] == 1) return;
                return error.IncompleteHuffmanTree;
            }
        }

        /// Finds the symbol for `code`, a `max_code_bits`-wide value whose
        /// most significant bits are matched against symbol codes.
        /// Returns error.InvalidCode when no symbol matches.
        pub fn find(self: *Self, code: u16) !Symbol {
            // try to find in lookup table
            const idx = code >> lookup_shift;
            const sym = self.lookup[idx];
            if (sym.code_bits != 0) return sym;
            // if not, walk the linked list of symbols with the same prefix
            return self.findLinked(code, sym.next);
        }

        /// Walks the linked list starting at symbol index `start` (0 = end of
        /// list) and returns the first symbol whose code matches the upper
        /// `code_bits` bits of `code`.
        inline fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
            var pos = start;
            while (pos > 0) {
                const sym = self.symbols[pos];
                const shift = max_code_bits - sym.code_bits;
                // compare code_bits number of upper bits
                if ((code ^ sym.code) >> shift == 0) return sym;
                pos = sym.next;
            }
            return error.InvalidCode;
        }
    };
}
test "flate.HuffmanDecoder init/find" {
    // example data from: https://youtu.be/SJPvNi4HrWQ?t=8423
    const code_lens = [_]u4{ 4, 3, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2 };
    var h: CodegenDecoder = .{};
    try h.generate(&code_lens);

    // Expected used symbols in sorted order; codes are left-aligned in 7 bits.
    const expected = [_]struct {
        sym: Symbol,
        code: u16,
    }{
        .{
            .code = 0b00_00000,
            .sym = .{ .symbol = 3, .code_bits = 2 },
        },
        .{
            .code = 0b01_00000,
            .sym = .{ .symbol = 18, .code_bits = 2 },
        },
        .{
            .code = 0b100_0000,
            .sym = .{ .symbol = 1, .code_bits = 3 },
        },
        .{
            .code = 0b101_0000,
            .sym = .{ .symbol = 4, .code_bits = 3 },
        },
        .{
            .code = 0b110_0000,
            .sym = .{ .symbol = 17, .code_bits = 3 },
        },
        .{
            .code = 0b1110_000,
            .sym = .{ .symbol = 0, .code_bits = 4 },
        },
        .{
            .code = 0b1111_000,
            .sym = .{ .symbol = 16, .code_bits = 4 },
        },
    };

    // 19 symbols, 7 used: the 12 unused ones (code_bits == 0) sort first.
    for (0..12) |i| {
        try testing.expectEqual(0, h.symbols[i].code_bits);
    }
    // used, from index 12
    for (expected, 12..) |e, i| {
        try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol);
        try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits);
        const sym_from_code = try h.find(e.code);
        try testing.expectEqual(e.sym.symbol, sym_from_code.symbol);
    }

    // All possible codes for each symbol.
    // Lookup table has 128 (1 << 7) elements, to cover all possible 7 bit codes.
    for (0b0000_000..0b0100_000) |c| // 0..32 (32)
        try testing.expectEqual(3, (try h.find(@intCast(c))).symbol);
    for (0b0100_000..0b1000_000) |c| // 32..64 (32)
        try testing.expectEqual(18, (try h.find(@intCast(c))).symbol);
    for (0b1000_000..0b1010_000) |c| // 64..80 (16)
        try testing.expectEqual(1, (try h.find(@intCast(c))).symbol);
    for (0b1010_000..0b1100_000) |c| // 80..96 (16)
        try testing.expectEqual(4, (try h.find(@intCast(c))).symbol);
    for (0b1100_000..0b1110_000) |c| // 96..112 (16)
        try testing.expectEqual(17, (try h.find(@intCast(c))).symbol);
    for (0b1110_000..0b1111_000) |c| // 112..120 (8)
        try testing.expectEqual(0, (try h.find(@intCast(c))).symbol);
    for (0b1111_000..0b1_0000_000) |c| // 120..128 (8)
        try testing.expectEqual(16, (try h.find(@intCast(c))).symbol);
}
const print = std.debug.print;
const assert = std.debug.assert;
const expect = std.testing.expect;
test "flate.HuffmanDecoder encode/decode literals" {
    const LiteralEncoder = @import("huffman_encoder.zig").LiteralEncoder;

    // Vary how many symbols get a non-zero frequency: stride 1 uses every
    // symbol, larger strides keep only every stride-th one.
    var stride: usize = 1;
    while (stride < 286) : (stride += 1) {
        // Symbol k gets weight k + 1 when (k + 1) is a multiple of stride.
        // End of block (256) is seeded first so it always has a code.
        var freq = [_]u16{0} ** 286;
        freq[256] = 1;
        for (&freq, 0..) |*f, k| {
            const weight = k + 1;
            if (weight % stride == 0) f.* = @intCast(weight);
        }

        // Build the encoder, then hand its code lengths to the decoder.
        var enc: LiteralEncoder = .{};
        enc.generate(&freq, 15);
        var code_lens = [_]u4{0} ** 286;
        for (&code_lens, 0..) |*len, k| len.* = @intCast(enc.codes[k].len);
        var dec: LiteralDecoder = .{};
        try dec.generate(&code_lens);

        // Each decoder code, bit-reversed, must equal the encoder's code
        // for the same alphabet position.
        for (dec.symbols) |s| {
            if (s.code_bits == 0) continue;
            const alphabet_idx: u16 = switch (s.kind) {
                .literal => s.symbol,
                .end_of_block => 256,
                .match => @as(u16, s.symbol) + 257,
            };
            const reversed: u16 = @bitReverse(@as(u15, @intCast(s.code)));
            try expect(enc.codes[alphabet_idx].code == reversed);
        }

        // Looking up each encoder code must yield the same code and length.
        for (enc.codes) |c| {
            if (c.len == 0) continue;
            const lookup_code: u15 = @bitReverse(@as(u15, @intCast(c.code)));
            const s = try dec.find(lookup_code);
            try expect(s.code == lookup_code);
            try expect(s.code_bits == c.len);
        }
    }
}