std.compress.zstd: it's compiling

This commit is contained in:
Andrew Kelley 2025-05-01 14:52:41 -07:00
parent 685d55c1a4
commit e05af2da13
5 changed files with 379 additions and 435 deletions

View File

@ -109,7 +109,7 @@ fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void {
in.initFixed(@constCast(compressed));
var zstd_stream: Decompress = .init(&in, .{});
try std.testing.expectError(error.ReadFailed, zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited));
try std.testing.expectError(err, zstd_stream.err.?);
try std.testing.expectError(err, zstd_stream.err orelse {});
return error.TestFailed;
}

View File

@ -11,12 +11,10 @@ state: State,
verify_checksum: bool,
err: ?Error = null,
const table_size_max = zstd.compressed_block.table_size_max;
const State = union(enum) {
new_frame,
in_frame: InFrame,
skipping_frame: u32,
skipping_frame: usize,
end,
const InFrame = struct {
@ -31,11 +29,38 @@ pub const Options = struct {
};
pub const Error = error{
BadMagic,
BlockOversize,
ChecksumFailure,
ContentOversize,
DictionaryIdFlagUnsupported,
MalformedBlock,
MalformedFrame,
EndOfStream,
HuffmanTreeIncomplete,
InvalidBitStream,
LiteralsBufferUndersize,
MalformedAccuracyLog,
MalformedBlock,
MalformedCompressedBlock,
MalformedFrame,
MalformedFseBits,
MalformedFseTable,
MalformedHuffmanTree,
MalformedLiteralsHeader,
MalformedLiteralsLength,
MalformedLiteralsSection,
MalformedSequence,
MissingStartBit,
OutputBufferUndersize,
InputBufferUndersize,
ReadFailed,
RepeatModeFirst,
ReservedBitSet,
ReservedBlock,
SequenceBufferUndersize,
TreelessLiteralsFirst,
UnexpectedEndOfLiteralStream,
WindowOversize,
WindowSizeUnknown,
};
pub fn init(input: *BufferedReader, options: Options) Decompress {
@ -69,25 +94,32 @@ fn read(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.R
d.err = err;
return error.ReadFailed;
};
return readInFrame(d, bw, limit, &d.state.in_frame) catch |err| {
d.err = err;
return error.ReadFailed;
return readInFrame(d, bw, limit, &d.state.in_frame) catch |err| switch (err) {
error.ReadFailed => return error.ReadFailed,
error.WriteFailed => return error.WriteFailed,
else => |e| {
d.err = e;
return error.ReadFailed;
},
};
},
.in_frame => |*in_frame| {
return readInFrame(d, bw, limit, in_frame) catch |err| {
d.err = err;
return error.ReadFailed;
return readInFrame(d, bw, limit, in_frame) catch |err| switch (err) {
error.ReadFailed => return error.ReadFailed,
error.WriteFailed => return error.WriteFailed,
else => |e| {
d.err = e;
return error.ReadFailed;
},
};
},
.skipping_frame => |*remaining| {
const requested = remaining.*;
const n = in.discard(.limited(requested)) catch |err| {
const n = in.discard(.limited(remaining.*)) catch |err| {
d.err = err;
return error.ReadFailed;
};
if (requested == n) d.state = .new_frame;
remaining.* = requested - n;
remaining.* -= n;
if (remaining.* == 0) d.state = .new_frame;
return 0;
},
.end => return error.EndOfStream,
@ -115,9 +147,9 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void {
fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: *State.InFrame) !usize {
const in = d.input;
var literal_fse_buffer: [table_size_max.literal]Table.Fse = undefined;
var match_fse_buffer: [table_size_max.match]Table.Fse = undefined;
var offset_fse_buffer: [table_size_max.offset]Table.Fse = undefined;
var literal_fse_buffer: [zstd.table_size_max.literal]Table.Fse = undefined;
var match_fse_buffer: [zstd.table_size_max.match]Table.Fse = undefined;
var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined;
var literals_buffer: [zstd.block_size_max]u8 = undefined;
var sequence_buffer: [zstd.block_size_max]u8 = undefined;
@ -125,10 +157,10 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
const header_bytes = try in.takeArray(3);
const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*);
const block_size = block_header.block_size;
const block_size = block_header.size;
if (state.frame.block_size_max < block_size) return error.BlockOversize;
if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize;
switch (block_header.block_type) {
switch (block_header.type) {
.raw => {
try in.readAll(bw, .limited(block_size));
return block_size;
@ -151,9 +183,10 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
var bytes_written: usize = 0;
{
if (sequence_buffer.len < @intFromEnum(remaining))
return error.SequenceBufferTooSmall;
const seq_len = try in.readSlice(remaining.slice(&sequence_buffer));
var bit_stream = try ReverseBitReader.init(sequence_buffer[0..seq_len]);
return error.SequenceBufferUndersize;
const seq_slice = remaining.slice(&sequence_buffer);
try in.readSlice(seq_slice);
var bit_stream = try ReverseBitReader.init(seq_slice);
if (sequences_header.sequence_count > 0) {
try decode.readInitialFseState(&bit_stream);
@ -205,16 +238,16 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state:
}
}
if (block_header.last_block) {
if (block_header.last) {
if (state.frame.has_checksum) {
const expected_checksum = try in.readInt(u32, .little);
const expected_checksum = try in.takeInt(u32, .little);
if (state.frame.hasher_opt) |*hasher| {
const actual_checksum: u32 = @truncate(hasher.final());
if (expected_checksum != actual_checksum) return error.ChecksumFailure;
}
}
if (d.frame.content_size) |content_size| {
if (content_size != d.current_frame_decompressed_size) {
if (state.frame.content_size) |content_size| {
if (content_size != state.decompressed_size) {
return error.MalformedFrame;
}
}
@ -249,16 +282,16 @@ pub const Frame = struct {
_,
pub fn kind(m: Magic) ?Kind {
return switch (m) {
.zstandard => .zstandard,
Skippable.magic_min...Skippable.magic_max => .skippable,
return switch (@intFromEnum(m)) {
@intFromEnum(Magic.zstandard) => .zstandard,
@intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => .skippable,
else => null,
};
}
pub fn isSkippable(m: Magic) bool {
return switch (m) {
Skippable.magic_min...Skippable.magic_max => true,
return switch (@intFromEnum(m)) {
@intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => true,
else => false,
};
}
@ -384,9 +417,9 @@ pub const Frame = struct {
) Decode {
return .{
.repeat_offsets = .{
zstd.compressed_block.start_repeated_offset_1,
zstd.compressed_block.start_repeated_offset_2,
zstd.compressed_block.start_repeated_offset_3,
zstd.start_repeated_offset_1,
zstd.start_repeated_offset_2,
zstd.start_repeated_offset_3,
},
.offset = undefined,
@ -410,7 +443,7 @@ pub const Frame = struct {
pub const PrepareError = error{
/// the (reversed) literal bitstream's first byte does not have any bits set
BitStreamHasNoStartBit,
MissingStartBit,
/// `literals` is a treeless literals section and the decode state does not
/// have a Huffman tree from a previous block
TreelessLiteralsFirst,
@ -422,6 +455,8 @@ pub const Frame = struct {
MalformedFseTable,
/// input stream ends before all FSE tables are read
EndOfStream,
ReadFailed,
InputBufferUndersize,
};
/// Prepare the decoder to decode a compressed block. Loads the literals
@ -430,6 +465,7 @@ pub const Frame = struct {
pub fn prepare(
self: *Decode,
in: *BufferedReader,
remaining: *Reader.Limit,
literals: LiteralsSection,
sequences_header: SequencesSection.Header,
) PrepareError!void {
@ -455,17 +491,14 @@ pub const Frame = struct {
}
if (sequences_header.sequence_count > 0) {
try self.updateFseTable(in, .literal, sequences_header.literal_lengths);
try self.updateFseTable(in, .offset, sequences_header.offsets);
try self.updateFseTable(in, .match, sequences_header.match_lengths);
try self.updateFseTable(in, remaining, .literal, sequences_header.literal_lengths);
try self.updateFseTable(in, remaining, .offset, sequences_header.offsets);
try self.updateFseTable(in, remaining, .match, sequences_header.match_lengths);
self.fse_tables_undefined = false;
}
}
/// Read initial FSE states for sequence decoding.
///
/// Errors returned:
/// - `error.EndOfStream` if `bit_reader` does not contain enough bits.
pub fn readInitialFseState(self: *Decode, bit_reader: *ReverseBitReader) error{EndOfStream}!void {
self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
@ -490,6 +523,7 @@ pub const Frame = struct {
const DataType = enum { offset, match, literal };
/// TODO: don't use `@field`
fn updateState(
self: *Decode,
comptime choice: DataType,
@ -517,9 +551,11 @@ pub const Frame = struct {
EndOfStream,
};
/// TODO: don't use `@field`
fn updateFseTable(
self: *Decode,
source: *BufferedReader,
in: *BufferedReader,
remaining: *Reader.Limit,
comptime choice: DataType,
mode: SequencesSection.Header.Mode,
) !void {
@ -527,28 +563,32 @@ pub const Frame = struct {
switch (mode) {
.predefined => {
@field(self, field_name).accuracy_log =
@field(zstd.compressed_block.default_accuracy_log, field_name);
@field(zstd.default_accuracy_log, field_name);
@field(self, field_name).table =
@field(Table, "predefined_" ++ field_name);
},
.rle => {
@field(self, field_name).accuracy_log = 0;
@field(self, field_name).table = .{ .rle = try source.readByte() };
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
@field(self, field_name).table = .{ .rle = try in.takeByte() };
},
.fse => {
var bit_reader: std.io.BitReader(.little) = .init(source);
if (in.buffer.len < @intFromEnum(remaining.*)) return error.InputBufferUndersize;
const limited_buffer = try in.peek(@intFromEnum(remaining.*));
var bit_reader: BitReader = .{ .bytes = limited_buffer };
const table_size = try Table.decode(
&bit_reader,
@field(zstd.compressed_block.table_symbol_count_max, field_name),
@field(zstd.compressed_block.table_accuracy_log_max, field_name),
@field(zstd.table_symbol_count_max, field_name),
@field(zstd.table_accuracy_log_max, field_name),
@field(self, field_name ++ "_fse_buffer"),
);
@field(self, field_name).table = .{
.fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
};
@field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
in.toss(bit_reader.index);
remaining.* = remaining.subtract(bit_reader.index).?;
},
.repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst,
}
@ -571,15 +611,15 @@ pub const Frame = struct {
const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);
const match_code = self.getCode(.match);
if (match_code >= zstd.compressed_block.match_length_code_table.len)
if (match_code >= zstd.match_length_code_table.len)
return error.InvalidBitStream;
const match = zstd.compressed_block.match_length_code_table[match_code];
const match = zstd.match_length_code_table[match_code];
const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);
const literal_code = self.getCode(.literal);
if (literal_code >= zstd.compressed_block.literals_length_code_table.len)
if (literal_code >= zstd.literals_length_code_table.len)
return error.InvalidBitStream;
const literal = zstd.compressed_block.literals_length_code_table[literal_code];
const literal = zstd.literals_length_code_table[literal_code];
const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);
const offset = if (offset_value > 3) offset: {
@ -622,12 +662,17 @@ pub const Frame = struct {
/// The `BufferedWriter` storage capacity is not large enough to
/// accept this stream.
OutputBufferUndersize,
WriteFailed,
MalformedLiteralsLength,
MalformedFseBits,
MissingStartBit,
HuffmanTreeIncomplete,
};
/// Decode one sequence from `bit_reader` into `dest`. Updates FSE states
/// if `last_sequence` is `false`. Assumes `prepare` called for the block
/// before attempting to decode sequences.
pub fn decodeSequence(
fn decodeSequence(
self: *Decode,
dest: *BufferedWriter,
bit_reader: *ReverseBitReader,
@ -662,13 +707,13 @@ pub const Frame = struct {
return sequence_length;
}
fn nextLiteralMultiStream(self: *Decode) error{BitStreamHasNoStartBit}!void {
fn nextLiteralMultiStream(self: *Decode) error{MissingStartBit}!void {
self.literal_stream_index += 1;
try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
}
fn initLiteralStream(self: *Decode, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
try self.literal_stream_reader.init(bytes);
fn initLiteralStream(self: *Decode, bytes: []const u8) error{MissingStartBit}!void {
self.literal_stream_reader = try ReverseBitReader.init(bytes);
}
fn isLiteralStreamEmpty(self: *Decode) bool {
@ -679,7 +724,7 @@ pub const Frame = struct {
}
const LiteralBitsError = error{
BitStreamHasNoStartBit,
MissingStartBit,
UnexpectedEndOfLiteralStream,
};
fn readLiteralsBits(
@ -704,6 +749,9 @@ pub const Frame = struct {
/// Problems decoding Huffman compressed literals
UnexpectedEndOfLiteralStream,
OutputBufferUndersize,
WriteFailed,
MissingStartBit,
HuffmanTreeIncomplete,
};
/// Decode `len` bytes of literals into `dest`.
@ -765,6 +813,7 @@ pub const Frame = struct {
}
}
/// TODO: don't use `@field`
fn getCode(self: *Decode, comptime choice: DataType) u32 {
return switch (@field(self, @tagName(choice)).table) {
.rle => |value| value,
@ -785,21 +834,17 @@ pub const Frame = struct {
};
const InitError = error{
/// Frame uses a dictionary.
DictionaryIdFlagUnsupported,
/// Frame does not have a valid window size.
WindowSizeUnknown,
WindowTooLarge,
ContentSizeTooLarge,
/// Window size exceeds `window_size_max` or max `usize` value.
WindowOversize,
/// Frame header indicates a content size exceeding max `usize` value.
ContentOversize,
};
/// Validates `frame_header` and returns the associated `Frame`.
///
/// Errors returned:
/// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
/// - `error.WindowSizeUnknown` if the frame does not have a valid window
/// size
/// - `error.WindowTooLarge` if the window size is larger than
/// `window_size_max` or `std.math.intMax(usize)`
/// - `error.ContentSizeTooLarge` if the frame header indicates a content
/// size larger than `std.math.maxInt(usize)`
pub fn init(
frame_header: Frame.Zstandard.Header,
window_size_max: usize,
@ -810,15 +855,15 @@ pub const Frame = struct {
const window_size_raw = frame_header.windowSize() orelse return error.WindowSizeUnknown;
const window_size = if (window_size_raw > window_size_max)
return error.WindowTooLarge
return error.WindowOversize
else
std.math.cast(usize, window_size_raw) orelse return error.WindowTooLarge;
std.math.cast(usize, window_size_raw) orelse return error.WindowOversize;
const should_compute_checksum =
frame_header.descriptor.content_checksum_flag and verify_checksum;
const content_size = if (frame_header.content_size) |size|
std.math.cast(usize, size) orelse return error.ContentSizeTooLarge
std.math.cast(usize, size) orelse return error.ContentOversize
else
null;
@ -875,13 +920,11 @@ pub const LiteralsSection = struct {
compressed_size: ?u18,
/// Decode a literals section header.
///
/// Errors returned:
/// - `error.EndOfStream` if there are not enough bytes in `source`
pub fn decode(source: *BufferedReader) !Header {
const byte0 = try source.readByte();
const block_type = @as(BlockType, @enumFromInt(byte0 & 0b11));
const size_format = @as(u2, @intCast((byte0 & 0b1100) >> 2));
pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) !Header {
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
const byte0 = try in.takeByte();
const block_type: BlockType = @enumFromInt(byte0 & 0b11);
const size_format: u2 = @intCast((byte0 & 0b1100) >> 2);
var regenerated_size: u20 = undefined;
var compressed_size: ?u18 = null;
switch (block_type) {
@ -890,28 +933,37 @@ pub const LiteralsSection = struct {
0, 2 => {
regenerated_size = byte0 >> 3;
},
1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4),
3 => regenerated_size = (byte0 >> 4) +
(@as(u20, try source.readByte()) << 4) +
(@as(u20, try source.readByte()) << 12),
1 => {
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
regenerated_size = (byte0 >> 4) + (@as(u20, try in.takeByte()) << 4);
},
3 => {
remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
regenerated_size = (byte0 >> 4) +
(@as(u20, try in.takeByte()) << 4) +
(@as(u20, try in.takeByte()) << 12);
},
}
},
.compressed, .treeless => {
const byte1 = try source.readByte();
const byte2 = try source.readByte();
remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
const byte1 = try in.takeByte();
const byte2 = try in.takeByte();
switch (size_format) {
0, 1 => {
regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4);
compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2);
},
2 => {
const byte3 = try source.readByte();
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
const byte3 = try in.takeByte();
regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12);
compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6);
},
3 => {
const byte3 = try source.readByte();
const byte4 = try source.readByte();
remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
const byte3 = try in.takeByte();
const byte4 = try in.takeByte();
regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12);
compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10);
},
@ -950,17 +1002,17 @@ pub const LiteralsSection = struct {
index: usize,
};
pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{NotFound}!Result {
pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{HuffmanTreeIncomplete}!Result {
var node = self.nodes[index];
const weight = node.weight;
var i: usize = index;
while (node.weight == weight) {
if (node.prefix == prefix) return Result{ .symbol = node.symbol };
if (i == 0) return error.NotFound;
if (node.prefix == prefix) return .{ .symbol = node.symbol };
if (i == 0) return error.HuffmanTreeIncomplete;
i -= 1;
node = self.nodes[i];
}
return Result{ .index = i };
return .{ .index = i };
}
pub fn weightToBitCount(weight: u4, max_bit_count: u4) u4 {
@ -975,20 +1027,26 @@ pub const LiteralsSection = struct {
MissingStartBit,
};
pub fn decode(in: *BufferedReader) HuffmanTree.DecodeError!HuffmanTree {
pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) HuffmanTree.DecodeError!HuffmanTree {
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
const header = try in.takeByte();
if (header < 128) {
return decodeFse(in, header);
return decodeFse(in, remaining, header);
} else {
return decodeDirect(in, header - 127);
return decodeDirect(in, remaining, header - 127);
}
}
fn decodeDirect(source: *BufferedReader, encoded_symbol_count: usize) HuffmanTree.DecodeError!HuffmanTree {
fn decodeDirect(
in: *BufferedReader,
remaining: *Reader.Limit,
encoded_symbol_count: usize,
) HuffmanTree.DecodeError!HuffmanTree {
var weights: [256]u4 = undefined;
const weights_byte_count = (encoded_symbol_count + 1) / 2;
remaining.* = remaining.subtract(weights_byte_count) orelse return error.EndOfStream;
for (0..weights_byte_count) |i| {
const byte = try source.takeByte();
const byte = try in.takeByte();
weights[2 * i] = @as(u4, @intCast(byte >> 4));
weights[2 * i + 1] = @as(u4, @intCast(byte & 0xF));
}
@ -996,22 +1054,25 @@ pub const LiteralsSection = struct {
return build(&weights, symbol_count);
}
fn decodeFse(in: *BufferedReader, compressed_size: usize) HuffmanTree.DecodeError!HuffmanTree {
fn decodeFse(
in: *BufferedReader,
remaining: *Reader.Limit,
compressed_size: usize,
) HuffmanTree.DecodeError!HuffmanTree {
var weights: [256]u4 = undefined;
remaining.* = remaining.subtract(compressed_size) orelse return error.EndOfStream;
const compressed_buffer = try in.take(compressed_size);
var limited_stream: BufferedReader = undefined;
limited_stream.initFixed(compressed_buffer);
var bit_reader: std.io.BitReader(.little) = .init(&limited_stream);
var bit_reader: BitReader = .{ .bytes = compressed_buffer };
var entries: [1 << 6]Table.Fse = undefined;
const table_size = try Table.decode(&bit_reader, 256, 6, &entries);
const accuracy_log = std.math.log2_int_ceil(usize, table_size);
const remaining = limited_stream.bufferContents();
const symbol_count = try assignWeights(remaining, accuracy_log, &entries, weights);
const remaining_buffer = bit_reader.bytes[bit_reader.index..];
const symbol_count = try assignWeights(remaining_buffer, accuracy_log, &entries, &weights);
return build(&weights, symbol_count);
}
fn assignWeights(
huff_bits_buffer: []u8,
huff_bits_buffer: []const u8,
accuracy_log: u16,
entries: *[1 << 6]Table.Fse,
weights: *[256]u4,
@ -1159,14 +1220,18 @@ pub const LiteralsSection = struct {
MalformedHuffmanTree,
/// Not enough bytes to complete the section.
EndOfStream,
ReadFailed,
LiteralsBufferUndersize,
MissingStartBit,
};
pub fn decode(source: *BufferedReader, buffer: []u8) DecodeError!LiteralsSection {
const header = try Header.decode(source);
pub fn decode(in: *BufferedReader, remaining: *Reader.Limit, buffer: []u8) DecodeError!LiteralsSection {
const header = try Header.decode(in, remaining);
switch (header.block_type) {
.raw => {
if (buffer.len < header.regenerated_size) return error.LiteralsBufferTooSmall;
try source.readNoEof(buffer[0..header.regenerated_size]);
if (buffer.len < header.regenerated_size) return error.LiteralsBufferUndersize;
remaining.* = remaining.subtract(header.regenerated_size) orelse return error.EndOfStream;
try in.readSlice(buffer[0..header.regenerated_size]);
return .{
.header = header,
.huffman_tree = null,
@ -1174,7 +1239,8 @@ pub const LiteralsSection = struct {
};
},
.rle => {
buffer[0] = try source.readByte();
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
buffer[0] = try in.takeByte();
return .{
.header = header,
.huffman_tree = null,
@ -1182,19 +1248,18 @@ pub const LiteralsSection = struct {
};
},
.compressed, .treeless => {
var counting_reader = std.io.countingReader(source);
const before_remaining = remaining.*;
const huffman_tree = if (header.block_type == .compressed)
try HuffmanTree.decode(counting_reader.reader(), buffer)
try HuffmanTree.decode(in, remaining)
else
null;
const huffman_tree_size = @as(usize, @intCast(counting_reader.bytes_read));
const huffman_tree_size = @intFromEnum(before_remaining) - @intFromEnum(remaining.*);
const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
return error.MalformedLiteralsSection;
if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
try source.readNoEof(buffer[0..total_streams_size]);
if (total_streams_size > buffer.len) return error.LiteralsBufferUndersize;
remaining.* = remaining.subtract(total_streams_size) orelse return error.EndOfStream;
try in.readSlice(buffer[0..total_streams_size]);
const stream_data = buffer[0..total_streams_size];
const streams = try Streams.decode(header.size_format, stream_data);
return .{
.header = header,
@ -1207,7 +1272,7 @@ pub const LiteralsSection = struct {
};
pub const SequencesSection = struct {
header: SequencesSection.Header,
header: Header,
literals_length_table: Table,
offset_table: Table,
match_length_table: Table,
@ -1228,32 +1293,37 @@ pub const SequencesSection = struct {
pub const DecodeError = error{
ReservedBitSet,
EndOfStream,
ReadFailed,
};
pub fn decode(source: *BufferedReader) DecodeError!Header {
pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) DecodeError!Header {
var sequence_count: u24 = undefined;
const byte0 = try source.readByte();
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
const byte0 = try in.takeByte();
if (byte0 == 0) {
return SequencesSection.Header{
return .{
.sequence_count = 0,
.offsets = undefined,
.match_lengths = undefined,
.literal_lengths = undefined,
};
} else if (byte0 < 128) {
remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
sequence_count = byte0;
} else if (byte0 < 255) {
sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte();
remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
sequence_count = (@as(u24, (byte0 - 128)) << 8) + try in.takeByte();
} else {
sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00;
remaining.* = remaining.subtract(3) orelse return error.EndOfStream;
sequence_count = (try in.takeByte()) + (@as(u24, try in.takeByte()) << 8) + 0x7F00;
}
const compression_modes = try source.readByte();
const compression_modes = try in.takeByte();
const matches_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b00001100) >> 2));
const offsets_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b00110000) >> 4));
const literal_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b11000000) >> 6));
const matches_mode: Header.Mode = @enumFromInt((compression_modes & 0b00001100) >> 2);
const offsets_mode: Header.Mode = @enumFromInt((compression_modes & 0b00110000) >> 4);
const literal_mode: Header.Mode = @enumFromInt((compression_modes & 0b11000000) >> 6);
if (compression_modes & 0b11 != 0) return error.ReservedBitSet;
return .{
@ -1277,7 +1347,7 @@ pub const Table = union(enum) {
};
pub fn decode(
bit_reader: *std.io.BitReader(.little),
bit_reader: *BitReader,
expected_symbol_count: usize,
max_accuracy_log: u4,
entries: []Table.Fse,
@ -1600,6 +1670,22 @@ pub const Table = union(enum) {
};
};
/// Lookup table of bit masks: `low_bit_mask[n]` has the low `n` bits set
/// (n in 0..8). Used by the bit readers to keep or strip partial bytes
/// without computing `(1 << n) - 1` at runtime.
const low_bit_mask = [9]u8{
    0b00000000,
    0b00000001,
    0b00000011,
    0b00000111,
    0b00001111,
    0b00011111,
    0b00111111,
    0b01111111,
    0b11111111,
};
/// Result tuple type for the bit readers: the decoded value of type `T`
/// paired with the number of bits actually read (may be fewer than
/// requested when the input runs out).
fn Bits(comptime T: type) type {
    return struct { T, u16 };
}
/// For reading the reversed bit streams used to encode FSE compressed data.
const ReverseBitReader = struct {
bytes: []const u8,
@ -1619,20 +1705,175 @@ const ReverseBitReader = struct {
return error.MissingStartBit;
}
fn readBitsNoEof(self: *ReverseBitReader, comptime U: type, num_bits: u16) error{EndOfStream}!U {
return self.bit_reader.readBitsNoEof(U, num_bits);
fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
return .{
@bitCast(@as(UT, @intCast(out))),
num,
};
}
fn readBits(self: *ReverseBitReader, comptime U: type, num_bits: u16, out_bits: *u16) error{}!U {
return try self.bit_reader.readBits(U, num_bits, out_bits);
fn readBitsNoEof(self: *ReverseBitReader, comptime T: type, num: u16) error{EndOfStream}!T {
const b, const c = try self.readBitsTuple(T, num);
if (c < num) return error.EndOfStream;
return b;
}
fn alignToByte(self: *ReverseBitReader) void {
self.bit_reader.alignToByte();
fn readBits(self: *ReverseBitReader, comptime T: type, num: u16, out_bits: *u16) !T {
const b, const c = try self.readBitsTuple(T, num);
out_bits.* = c;
return b;
}
fn readBitsTuple(self: *ReverseBitReader, comptime T: type, num: u16) !Bits(T) {
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
const U = if (@bitSizeOf(T) < 8) u8 else UT;
if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);
var out_count: u16 = self.count;
var out: U = self.removeBits(self.count);
const full_bytes_left = (num - out_count) / 8;
for (0..full_bytes_left) |_| {
const byte = takeByte(self) catch |err| switch (err) {
error.EndOfStream => return initBits(T, out, out_count),
};
if (U == u8) out = 0 else out <<= 8;
out |= byte;
out_count += 8;
}
const bits_left = num - out_count;
const keep = 8 - bits_left;
if (bits_left == 0) return initBits(T, out, out_count);
const final_byte = takeByte(self) catch |err| switch (err) {
error.EndOfStream => return initBits(T, out, out_count),
};
out <<= @intCast(bits_left);
out |= final_byte >> @intCast(keep);
self.bits = final_byte & low_bit_mask[keep];
self.count = @intCast(keep);
return initBits(T, out, num);
}
fn takeByte(rbr: *ReverseBitReader) error{EndOfStream}!u8 {
if (rbr.remaining == 0) return error.EndOfStream;
rbr.remaining -= 1;
return rbr.bytes[rbr.remaining];
}
fn isEmpty(self: *const ReverseBitReader) bool {
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.count == 0;
return self.remaining == 0 and self.count == 0;
}
fn removeBits(self: *ReverseBitReader, num: u4) u8 {
if (num == 8) {
self.count = 0;
return self.bits;
}
const keep = self.count - num;
const bits = self.bits >> @intCast(keep);
self.bits &= low_bit_mask[keep];
self.count = keep;
return bits;
}
};
/// Forward, little-endian bit reader over an in-memory byte slice.
/// Bits are taken from the low end of each byte and placed above
/// previously-read bits in the output (LSB-first packing). Up to 7
/// already-fetched but unconsumed bits are cached in `bits`/`count`.
const BitReader = struct {
    /// Backing storage; never modified.
    bytes: []const u8,
    /// Index of the next byte to fetch from `bytes`.
    index: usize = 0,
    /// Cached bits not yet consumed, right-aligned (low `count` bits valid).
    bits: u8 = 0,
    /// Number of valid bits in `bits` (0-7 between calls; 8 only transiently).
    count: u4 = 0,

    /// Pack a value and a bit count into a `Bits(T)` tuple, bit-casting
    /// through the unsigned integer of `T`'s width.
    fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        return .{
            @bitCast(@as(UT, @intCast(out))),
            num,
        };
    }

    /// Read exactly `num` bits as a `T`; `error.EndOfStream` if fewer
    /// bits remain in the input.
    fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
        const b, const c = try self.readBitsTuple(T, num);
        if (c < num) return error.EndOfStream;
        return b;
    }

    /// Read up to `num` bits as a `T`; the count actually read is stored
    /// in `out_bits` (running out of input is not an error here).
    fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
        const b, const c = try self.readBitsTuple(T, num);
        out_bits.* = c;
        return b;
    }

    /// Core read: returns (value, bits_read). Drains the bit cache first,
    /// then whole bytes, then a final partial byte whose leftover high
    /// bits are stashed back into the cache.
    fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        // Accumulate in at least a u8 so sub-byte result types can still
        // hold a full fetched byte during assembly.
        const U = if (@bitSizeOf(T) < 8) u8 else UT;
        // Fast path: the cache alone satisfies the request.
        if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);
        var out_count: u16 = self.count;
        var out: U = self.removeBits(self.count);
        // Fetch as many whole bytes as still fit below `num` bits.
        const full_bytes_left = (num - out_count) / 8;
        for (0..full_bytes_left) |_| {
            const byte = takeByte(self) catch |err| switch (err) {
                // Short read: report what we have.
                error.EndOfStream => return initBits(T, out, out_count),
            };
            // Little endian: new byte goes above the bits already read.
            const pos = @as(U, byte) << @intCast(out_count);
            out |= pos;
            out_count += 8;
        }
        const bits_left = num - out_count;
        const keep = 8 - bits_left;
        if (bits_left == 0) return initBits(T, out, out_count);
        const final_byte = takeByte(self) catch |err| switch (err) {
            error.EndOfStream => return initBits(T, out, out_count),
        };
        // Take the low `bits_left` bits of the final byte...
        const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
        out |= pos;
        // ...and cache its remaining high bits for the next call.
        self.bits = final_byte >> @intCast(bits_left);
        self.count = @intCast(keep);
        return initBits(T, out, num);
    }

    /// Fetch the next byte from the slice, advancing `index`.
    fn takeByte(br: *BitReader) error{EndOfStream}!u8 {
        if (br.bytes.len - br.index == 0) return error.EndOfStream;
        const result = br.bytes[br.index];
        br.index += 1;
        return result;
    }

    /// Pop `num` bits (num <= count) from the low end of the cache.
    fn removeBits(self: *@This(), num: u4) u8 {
        if (num == 8) {
            // Whole cached byte requested; empty the cache.
            self.count = 0;
            return self.bits;
        }
        const keep = self.count - num;
        const bits = self.bits & low_bit_mask[num];
        self.bits >>= @intCast(num);
        self.count = keep;
        return bits;
    }

    /// Discard any cached partial byte so the next read starts on a byte
    /// boundary.
    fn alignToByte(self: *@This()) void {
        self.bits = 0;
        self.count = 0;
    }
};

View File

@ -19,16 +19,12 @@ pub const AllocatingWriter = @import("io/AllocatingWriter.zig");
pub const MultiWriter = @import("io/multi_writer.zig").MultiWriter;
pub const multiWriter = @import("io/multi_writer.zig").multiWriter;
pub const BitReader = @import("io/bit_reader.zig").Type;
pub const BitWriter = @import("io/bit_writer.zig").BitWriter;
pub const bitWriter = @import("io/bit_writer.zig").bitWriter;
pub const ChangeDetectionStream = @import("io/change_detection_stream.zig").ChangeDetectionStream;
pub const changeDetectionStream = @import("io/change_detection_stream.zig").changeDetectionStream;
pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAtomicFile;
pub const tty = @import("io/tty.zig");
pub fn poll(
@ -437,13 +433,11 @@ pub fn PollFiles(comptime StreamEnum: type) type {
}
test {
_ = BufferedWriter;
_ = AllocatingWriter;
_ = BitWriter;
_ = BufferedReader;
_ = BufferedWriter;
_ = Reader;
_ = Writer;
_ = AllocatingWriter;
_ = @import("io/bit_reader.zig");
_ = @import("io/bit_writer.zig");
_ = @import("io/buffered_atomic_file.zig");
_ = @import("io/test.zig");
}

View File

@ -1,236 +0,0 @@
const std = @import("../std.zig");
const bit_reader = @This();
//General note on endianness:
//Big endian is packed starting in the most significant part of the byte and subsequent
// bytes contain less significant bits. Thus we always take bits from the high
// end and place them below existing bits in our output.
//Little endian is packed starting in the least significant part of the byte and
// subsequent bytes contain more significant bits. Thus we always take bits from
// the low end and place them above existing bits in our output.
//Regardless of endianness, within any given byte the bits are always in most
// to least significant order.
//Also regardless of endianness, the buffer always aligns bits to the low end
// of the byte.
/// Creates a bit reader which allows for reading bits from an underlying standard reader
pub fn Type(comptime endian: std.builtin.Endian) type {
    return struct {
        // Underlying byte source; bytes are pulled from it on demand.
        reader: *std.io.BufferedReader,
        // Partially-consumed byte. Regardless of endianness, the remaining
        // bits are kept aligned to the low end of this byte.
        bits: u8,
        // Number of valid bits currently held in `bits` (0-8).
        count: u4,

        // low_bit_mask[n] selects the n least significant bits of a byte.
        const low_bit_mask = [9]u8{
            0b00000000,
            0b00000001,
            0b00000011,
            0b00000111,
            0b00001111,
            0b00011111,
            0b00111111,
            0b01111111,
            0b11111111,
        };

        /// Initializes an empty bit reader over `reader`.
        pub fn init(reader: *std.io.BufferedReader) @This() {
            return .{ .reader = reader, .bits = 0, .count = 0 };
        }

        // Result pair: the value read, and how many bits were actually read.
        fn Bits(comptime T: type) type {
            return struct { T, u16 };
        }

        // Wraps the raw unsigned accumulator into the caller's requested type
        // (which may be signed or narrower than a byte), paired with the
        // number of bits actually read.
        fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
            const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
            return .{
                @bitCast(@as(UT, @intCast(out))),
                num,
            };
        }

        /// Reads `bits` bits from the reader and returns a specified type
        /// containing them in the least significant end, returning an error if the
        /// specified number of bits could not be read.
        pub fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
            const b, const c = try self.readBitsTuple(T, num);
            if (c < num) return error.EndOfStream;
            return b;
        }

        /// Reads `bits` bits from the reader and returns a specified type
        /// containing them in the least significant end. The number of bits successfully
        /// read is placed in `out_bits`, as reaching the end of the stream is not an error.
        pub fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
            const b, const c = try self.readBitsTuple(T, num);
            out_bits.* = c;
            return b;
        }

        /// Reads `bits` bits from the reader and returns a tuple of the specified type
        /// containing them in the least significant end, and the number of bits successfully
        /// read. Reaching the end of the stream is not an error.
        pub fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
            const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
            const U = if (@bitSizeOf(T) < 8) u8 else UT; //it is a pain to work with <u8

            //dump any bits in our buffer first
            if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);

            var out_count: u16 = self.count;
            var out: U = self.removeBits(self.count);

            //grab all the full bytes we need and put their
            //bits where they belong
            const full_bytes_left = (num - out_count) / 8;
            for (0..full_bytes_left) |_| {
                const byte = self.reader.takeByte() catch |err| switch (err) {
                    // Short read: return whatever was assembled so far.
                    error.EndOfStream => return initBits(T, out, out_count),
                    else => |e| return e,
                };
                switch (endian) {
                    .big => {
                        // New byte goes below the bits we already have.
                        if (U == u8) out = 0 else out <<= 8; //shifting u8 by 8 is illegal in Zig
                        out |= byte;
                    },
                    .little => {
                        // New byte goes above the bits we already have.
                        const pos = @as(U, byte) << @intCast(out_count);
                        out |= pos;
                    },
                }
                out_count += 8;
            }

            // Fewer than 8 bits remain; they come from one final byte, whose
            // unused `keep` bits are stashed in the buffer for the next read.
            const bits_left = num - out_count;
            const keep = 8 - bits_left;

            if (bits_left == 0) return initBits(T, out, out_count);

            const final_byte = self.reader.takeByte() catch |err| switch (err) {
                error.EndOfStream => return initBits(T, out, out_count),
                else => |e| return e,
            };

            switch (endian) {
                .big => {
                    // Take the top `bits_left` bits; keep the low bits buffered.
                    out <<= @intCast(bits_left);
                    out |= final_byte >> @intCast(keep);
                    self.bits = final_byte & low_bit_mask[keep];
                },
                .little => {
                    // Take the low `bits_left` bits; keep the high bits buffered.
                    const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
                    out |= pos;
                    self.bits = final_byte >> @intCast(bits_left);
                },
            }

            self.count = @intCast(keep);
            return initBits(T, out, num);
        }

        //convenience function for removing bits from
        //the appropriate part of the buffer based on
        //endianess.
        fn removeBits(self: *@This(), num: u4) u8 {
            if (num == 8) {
                self.count = 0;
                return self.bits;
            }

            const keep = self.count - num;
            // Big endian consumes from the high end of the buffered bits;
            // little endian consumes from the low end.
            const bits = switch (endian) {
                .big => self.bits >> @intCast(keep),
                .little => self.bits & low_bit_mask[num],
            };
            switch (endian) {
                .big => self.bits &= low_bit_mask[keep],
                .little => self.bits >>= @intCast(num),
            }

            self.count = keep;
            return bits;
        }

        /// Discards any buffered partial byte so the next read starts at a
        /// byte boundary of the underlying reader.
        pub fn alignToByte(self: *@This()) void {
            self.bits = 0;
            self.count = 0;
        }
    };
}
///////////////////////////////
// Exercises the BitReader API for both endiannesses: variable-width reads,
// stream reset, zero-width reads, and end-of-stream behavior.
test "api coverage" {
    const mem_be = [_]u8{ 0b11001101, 0b00001011 };
    const mem_le = [_]u8{ 0b00011101, 0b10010101 };

    var mem_in_be = std.io.fixedBufferStream(&mem_be);
    var bit_stream_be: bit_reader.Type(.big) = .init(mem_in_be.reader());

    var out_bits: u16 = undefined;

    const expect = std.testing.expect;
    const expectError = std.testing.expectError;

    // Big endian consumes from the most significant bit first, so the 16-bit
    // stream splits as 1 | 10 | 011 | 0100 | 00101 | 1 = 1, 2, 3, 4, 5, 1.
    try expect(1 == try bit_stream_be.readBits(u2, 1, &out_bits));
    try expect(out_bits == 1);
    try expect(2 == try bit_stream_be.readBits(u5, 2, &out_bits));
    try expect(out_bits == 2);
    try expect(3 == try bit_stream_be.readBits(u128, 3, &out_bits));
    try expect(out_bits == 3);
    try expect(4 == try bit_stream_be.readBits(u8, 4, &out_bits));
    try expect(out_bits == 4);
    try expect(5 == try bit_stream_be.readBits(u9, 5, &out_bits));
    try expect(out_bits == 5);
    try expect(1 == try bit_stream_be.readBits(u1, 1, &out_bits));
    try expect(out_bits == 1);

    // Rewind the stream and drop buffered bits, then re-read wider values.
    mem_in_be.pos = 0;
    bit_stream_be.count = 0;
    try expect(0b110011010000101 == try bit_stream_be.readBits(u15, 15, &out_bits));
    try expect(out_bits == 15);
    mem_in_be.pos = 0;
    bit_stream_be.count = 0;
    try expect(0b1100110100001011 == try bit_stream_be.readBits(u16, 16, &out_bits));
    try expect(out_bits == 16);

    // Zero-width read succeeds; the exhausted stream then yields 0 of 1 bits,
    // and readBitsNoEof treats the short read as an error.
    _ = try bit_stream_be.readBits(u0, 0, &out_bits);
    try expect(0 == try bit_stream_be.readBits(u1, 1, &out_bits));
    try expect(out_bits == 0);
    try expectError(error.EndOfStream, bit_stream_be.readBitsNoEof(u1, 1));

    var mem_in_le = std.io.fixedBufferStream(&mem_le);
    var bit_stream_le: bit_reader.Type(.little) = .init(mem_in_le.reader());

    // Little endian consumes from the least significant bit first; the same
    // widths again decode to 1, 2, 3, 4, 5, 1 on this input.
    try expect(1 == try bit_stream_le.readBits(u2, 1, &out_bits));
    try expect(out_bits == 1);
    try expect(2 == try bit_stream_le.readBits(u5, 2, &out_bits));
    try expect(out_bits == 2);
    try expect(3 == try bit_stream_le.readBits(u128, 3, &out_bits));
    try expect(out_bits == 3);
    try expect(4 == try bit_stream_le.readBits(u8, 4, &out_bits));
    try expect(out_bits == 4);
    try expect(5 == try bit_stream_le.readBits(u9, 5, &out_bits));
    try expect(out_bits == 5);
    try expect(1 == try bit_stream_le.readBits(u1, 1, &out_bits));
    try expect(out_bits == 1);

    mem_in_le.pos = 0;
    bit_stream_le.count = 0;
    try expect(0b001010100011101 == try bit_stream_le.readBits(u15, 15, &out_bits));
    try expect(out_bits == 15);
    mem_in_le.pos = 0;
    bit_stream_le.count = 0;
    try expect(0b1001010100011101 == try bit_stream_le.readBits(u16, 16, &out_bits));
    try expect(out_bits == 16);

    _ = try bit_stream_le.readBits(u0, 0, &out_bits);
    try expect(0 == try bit_stream_le.readBits(u1, 1, &out_bits));
    try expect(out_bits == 0);
    try expectError(error.EndOfStream, bit_stream_le.readBitsNoEof(u1, 1));
}

View File

@@ -1,55 +0,0 @@
const std = @import("../std.zig");
const mem = std.mem;
const fs = std.fs;
const File = std.fs.File;
/// An `fs.AtomicFile` wrapped in a buffered writer: writes accumulate in a
/// fixed in-memory buffer and land in a temporary file that `finish` flushes
/// and atomically renames into place at `dest_path`.
pub const BufferedAtomicFile = struct {
    atomic_file: fs.AtomicFile,
    file_writer: File.Writer,
    buffered_writer: BufferedWriter,
    // Allocator used by `create`; `destroy` must free through the same one.
    allocator: mem.Allocator,

    pub const buffer_size = 4096;
    pub const BufferedWriter = std.io.BufferedWriter(buffer_size, File.Writer);
    pub const Writer = std.io.Writer(*BufferedWriter, BufferedWriter.Error, BufferedWriter.write);

    /// Allocates and initializes a BufferedAtomicFile targeting `dest_path`
    /// inside `dir`. On success the caller owns the result and must call
    /// `destroy` (even after a successful `finish`).
    /// TODO when https://github.com/ziglang/zig/issues/2761 is solved
    /// this API will not need an allocator
    pub fn create(
        allocator: mem.Allocator,
        dir: fs.Dir,
        dest_path: []const u8,
        atomic_file_options: fs.Dir.AtomicFileOptions,
    ) !*BufferedAtomicFile {
        var self = try allocator.create(BufferedAtomicFile);
        // Assign placeholder fields first so the allocator is defined for
        // the error-cleanup paths below; real values are filled in after.
        self.* = BufferedAtomicFile{
            .atomic_file = undefined,
            .file_writer = undefined,
            .buffered_writer = undefined,
            .allocator = allocator,
        };
        errdefer allocator.destroy(self);

        self.atomic_file = try dir.atomicFile(dest_path, atomic_file_options);
        errdefer self.atomic_file.deinit();

        // NOTE(review): buffered_writer holds a copy of file_writer; assumes
        // File.Writer is a plain value type safe to copy — confirm.
        self.file_writer = self.atomic_file.file.writer();
        self.buffered_writer = .{ .unbuffered_writer = self.file_writer };
        return self;
    }

    /// always call destroy, even after successful finish()
    pub fn destroy(self: *BufferedAtomicFile) void {
        // NOTE(review): relies on atomic_file.deinit being safe to call
        // after finish() — confirm against fs.AtomicFile.
        self.atomic_file.deinit();
        self.allocator.destroy(self);
    }

    /// Flushes buffered bytes to the temporary file, then atomically moves
    /// it to the destination path.
    pub fn finish(self: *BufferedAtomicFile) !void {
        try self.buffered_writer.flush();
        try self.atomic_file.finish();
    }

    /// Returns the writer clients should use; all writes go through the
    /// in-memory buffer.
    pub fn writer(self: *BufferedAtomicFile) Writer {
        return .{ .context = &self.buffered_writer };
    }
};