From ea9ded87582a8b9d0ed3afd3360a1d75f0359a5c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 24 Jan 2023 15:04:56 -0700 Subject: [PATCH] std.compress.xz public API cleanup * add xz to std.compress * prefer importing std.zig by file name, to reduce reliance on the standard library being a special case. * extract some types from inside generic functions. These types are the same regardless of the generic parameters. * expose some more types in the std.compress.xz namespace. * rename xz.stream to xz.decompress * rename check.Kind to Check * use std.leb for LEB instead of a redundant implementation --- lib/std/compress.zig | 2 + lib/std/compress/xz.zig | 141 +++++++++++++++++- lib/std/compress/xz/block.zig | 26 ++-- lib/std/compress/xz/check.zig | 7 - lib/std/compress/xz/lzma.zig | 2 +- lib/std/compress/xz/multibyte.zig | 23 --- lib/std/compress/xz/stream.zig | 136 ----------------- .../compress/xz/{stream_test.zig => test.zig} | 6 +- 8 files changed, 157 insertions(+), 186 deletions(-) delete mode 100644 lib/std/compress/xz/check.zig delete mode 100644 lib/std/compress/xz/multibyte.zig delete mode 100644 lib/std/compress/xz/stream.zig rename lib/std/compress/xz/{stream_test.zig => test.zig} (94%) diff --git a/lib/std/compress.zig b/lib/std/compress.zig index 3c52002cfc..334d7bfcb8 100644 --- a/lib/std/compress.zig +++ b/lib/std/compress.zig @@ -3,6 +3,7 @@ const std = @import("std.zig"); pub const deflate = @import("compress/deflate.zig"); pub const gzip = @import("compress/gzip.zig"); pub const zlib = @import("compress/zlib.zig"); +pub const xz = @import("compress/xz.zig"); pub fn HashedReader( comptime ReaderType: anytype, @@ -38,4 +39,5 @@ test { _ = deflate; _ = gzip; _ = zlib; + _ = xz; } diff --git a/lib/std/compress/xz.zig b/lib/std/compress/xz.zig index 3af2d91cfb..2c56be9c77 100644 --- a/lib/std/compress/xz.zig +++ b/lib/std/compress/xz.zig @@ -1,5 +1,142 @@ -pub usingnamespace @import("xz/stream.zig"); +const std = @import("std"); +const block = @import("xz/block.zig"); +const Allocator = std.mem.Allocator; +const Crc32 = std.hash.Crc32; + +pub const Flags = packed struct(u16) { + reserved1: u8, + check_kind: Check, + reserved2: u4, +}; + +pub const Header = extern struct { + magic: [6]u8, + flags: Flags, + crc32: u32, +}; + +pub const Footer = extern struct { + crc32: u32, + backward_size: u32, + flags: Flags, + magic: [2]u8, +}; + +pub const Check = enum(u4) { + none = 0x00, + crc32 = 0x01, + crc64 = 0x04, + sha256 = 0x0A, + _, +}; + +pub fn decompress(allocator: Allocator, reader: anytype) !Decompress(@TypeOf(reader)) { + return Decompress(@TypeOf(reader)).init(allocator, reader); +} + +pub fn Decompress(comptime ReaderType: type) type { + return struct { + const Self = @This(); + + pub const Error = ReaderType.Error || block.Decoder(ReaderType).Error; + pub const Reader = std.io.Reader(*Self, Error, read); + + allocator: Allocator, + block_decoder: block.Decoder(ReaderType), + in_reader: ReaderType, + + fn init(allocator: Allocator, source: ReaderType) !Self { + const header = try source.readStruct(Header); + + if (!std.mem.eql(u8, &header.magic, &.{ 0xFD, '7', 'z', 'X', 'Z', 0x00 })) + return error.BadHeader; + + if (header.flags.reserved1 != 0 or header.flags.reserved2 != 0) + return error.BadHeader; + + const hash = Crc32.hash(std.mem.asBytes(&header.flags)); + if (hash != header.crc32) + return error.WrongChecksum; + + return Self{ + .allocator = allocator, + .block_decoder = try block.decoder(allocator, source, header.flags.check_kind), + .in_reader = source, + }; + } + + pub fn deinit(self: *Self) void { + self.block_decoder.deinit(); + } + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + pub fn read(self: *Self, buffer: []u8) Error!usize { + if (buffer.len == 0) + return 0; + + const r = try self.block_decoder.read(buffer); + if (r != 0) + return r; + + const index_size = blk: { + var hasher = std.compress.hashedReader(self.in_reader, Crc32.init()); + hasher.hasher.update(&[1]u8{0x00}); + + var counter = std.io.countingReader(hasher.reader()); + counter.bytes_read += 1; + + const counting_reader = counter.reader(); + + const record_count = try std.leb.readULEB128(u64, counting_reader); + if (record_count != self.block_decoder.block_count) + return error.CorruptInput; + + var i: usize = 0; + while (i < record_count) : (i += 1) { + // TODO: validate records + _ = try std.leb.readULEB128(u64, counting_reader); + _ = try std.leb.readULEB128(u64, counting_reader); + } + + while (counter.bytes_read % 4 != 0) { + if (try counting_reader.readByte() != 0) + return error.CorruptInput; + } + + const hash_a = hasher.hasher.final(); + const hash_b = try counting_reader.readIntLittle(u32); + if (hash_a != hash_b) + return error.WrongChecksum; + + break :blk counter.bytes_read; + }; + + const footer = try self.in_reader.readStruct(Footer); + const backward_size = (footer.backward_size + 1) * 4; + if (backward_size != index_size) + return error.CorruptInput; + + if (footer.flags.reserved1 != 0 or footer.flags.reserved2 != 0) + return error.CorruptInput; + + var hasher = Crc32.init(); + hasher.update(std.mem.asBytes(&footer.backward_size)); + hasher.update(std.mem.asBytes(&footer.flags)); + const hash = hasher.final(); + if (hash != footer.crc32) + return error.WrongChecksum; + + if (!std.mem.eql(u8, &footer.magic, &.{ 'Y', 'Z' })) + return error.CorruptInput; + + return 0; + } + }; +} test { - _ = @import("xz/stream.zig"); + _ = @import("xz/test.zig"); } diff --git a/lib/std/compress/xz/block.zig b/lib/std/compress/xz/block.zig index 27b2fc0b5f..1ceaea4984 100644 --- a/lib/std/compress/xz/block.zig +++ b/lib/std/compress/xz/block.zig @@ -1,11 +1,10 @@ -const std = @import("std"); -const check = @import("check.zig"); +const std = @import("../../std.zig"); const lzma = @import("lzma.zig"); -const multibyte = @import("multibyte.zig"); const Allocator = std.mem.Allocator; const Crc32 = std.hash.Crc32; const Crc64 = std.hash.crc.Crc64Xz; const Sha256 = std.crypto.hash.sha2.Sha256; +const xz = std.compress.xz; const DecodeError = error{ CorruptInput, @@ -16,8 +15,8 @@ const DecodeError = error{ Overflow, }; -pub fn decoder(allocator: Allocator, reader: anytype, check_kind: check.Kind) !Decoder(@TypeOf(reader)) { - return Decoder(@TypeOf(reader)).init(allocator, reader, check_kind); +pub fn decoder(allocator: Allocator, reader: anytype, check: xz.Check) !Decoder(@TypeOf(reader)) { + return Decoder(@TypeOf(reader)).init(allocator, reader, check); } pub fn Decoder(comptime ReaderType: type) type { @@ -31,17 +30,17 @@ pub fn Decoder(comptime ReaderType: type) type { allocator: Allocator, inner_reader: ReaderType, - check_kind: check.Kind, + check: xz.Check, err: ?Error, accum: lzma.LzAccumBuffer, lzma_state: lzma.DecoderState, block_count: usize, - fn init(allocator: Allocator, in_reader: ReaderType, check_kind: check.Kind) !Self { + fn init(allocator: Allocator, in_reader: ReaderType, check: xz.Check) !Self { return Self{ .allocator = allocator, .inner_reader = in_reader, - .check_kind = check_kind, + .check = check, .err = null, .accum = .{}, .lzma_state = try lzma.DecoderState.init(allocator), @@ -116,10 +115,10 @@ pub fn Decoder(comptime ReaderType: type) type { return error.Unsupported; if (flags.has_packed_size) - packed_size = try multibyte.readInt(header_reader); + packed_size = try std.leb.readULEB128(u64, header_reader); if (flags.has_unpacked_size) - unpacked_size = try multibyte.readInt(header_reader); + unpacked_size = try std.leb.readULEB128(u64, header_reader); const FilterId = enum(u64) { lzma2 = 0x21, @@ -128,7 +127,7 @@ pub fn Decoder(comptime ReaderType: type) type { const filter_id = @intToEnum( FilterId, - try multibyte.readInt(header_reader), + try std.leb.readULEB128(u64, header_reader), ); if (@enumToInt(filter_id) >= 0x4000_0000_0000_0000) @@ -137,7 +136,7 @@ pub fn Decoder(comptime ReaderType: type) type { if (filter_id != .lzma2) return error.Unsupported; - const properties_size = try multibyte.readInt(header_reader); + const properties_size = try std.leb.readULEB128(u64, header_reader); if (properties_size != 1) return error.CorruptInput; @@ -177,8 +176,7 @@ pub fn Decoder(comptime ReaderType: type) type { return error.CorruptInput; } - // Check - switch (self.check_kind) { + switch (self.check) { .none => {}, .crc32 => { const hash_a = Crc32.hash(unpacked_bytes); diff --git a/lib/std/compress/xz/check.zig b/lib/std/compress/xz/check.zig deleted file mode 100644 index 20151ad4cf..0000000000 --- a/lib/std/compress/xz/check.zig +++ /dev/null @@ -1,7 +0,0 @@ -pub const Kind = enum(u4) { - none = 0x00, - crc32 = 0x01, - crc64 = 0x04, - sha256 = 0x0A, - _, -}; diff --git a/lib/std/compress/xz/lzma.zig b/lib/std/compress/xz/lzma.zig index ead707e0be..9fe941e2b1 100644 --- a/lib/std/compress/xz/lzma.zig +++ b/lib/std/compress/xz/lzma.zig @@ -1,6 +1,6 @@ // Ported from https://github.com/gendx/lzma-rs -const std = @import("std"); +const std = @import("../../std.zig"); const assert = std.debug.assert; const Allocator = std.mem.Allocator; const ArrayListUnmanaged = std.ArrayListUnmanaged; diff --git a/lib/std/compress/xz/multibyte.zig b/lib/std/compress/xz/multibyte.zig deleted file mode 100644 index 1226ffcfb2..0000000000 --- a/lib/std/compress/xz/multibyte.zig +++ /dev/null @@ -1,23 +0,0 @@ -const Multibyte = packed struct(u8) { - value: u7, - more: bool, -}; - -pub fn readInt(reader: anytype) !u64 { - const max_size = 9; - - var chunk = try reader.readStruct(Multibyte); - var num: u64 = chunk.value; - var i: u6 = 0; - - while (chunk.more) { - chunk = try reader.readStruct(Multibyte); - i += 1; - if (i >= max_size or @bitCast(u8, chunk) == 0x00) - return error.CorruptInput; - - num |= @as(u64, chunk.value) << (i * 7); - } - - return num; -} diff --git a/lib/std/compress/xz/stream.zig b/lib/std/compress/xz/stream.zig deleted file mode 100644 index 33916e20df..0000000000 --- a/lib/std/compress/xz/stream.zig +++ /dev/null @@ -1,136 +0,0 @@ -const std = @import("std"); -const block = @import("block.zig"); -const check = @import("check.zig"); -const multibyte = @import("multibyte.zig"); -const Allocator = std.mem.Allocator; -const Crc32 = std.hash.Crc32; - -test { - _ = @import("stream_test.zig"); -} - -const Flags = packed struct(u16) { - reserved1: u8, - check_kind: check.Kind, - reserved2: u4, -}; - -pub fn stream(allocator: Allocator, reader: anytype) !Stream(@TypeOf(reader)) { - return Stream(@TypeOf(reader)).init(allocator, reader); -} - -pub fn Stream(comptime ReaderType: type) type { - return struct { - const Self = @This(); - - pub const Error = ReaderType.Error || block.Decoder(ReaderType).Error; - pub const Reader = std.io.Reader(*Self, Error, read); - - allocator: Allocator, - block_decoder: block.Decoder(ReaderType), - in_reader: ReaderType, - - fn init(allocator: Allocator, source: ReaderType) !Self { - const Header = extern struct { - magic: [6]u8, - flags: Flags, - crc32: u32, - }; - - const header = try source.readStruct(Header); - - if (!std.mem.eql(u8, &header.magic, &.{ 0xFD, '7', 'z', 'X', 'Z', 0x00 })) - return error.BadHeader; - - if (header.flags.reserved1 != 0 or header.flags.reserved2 != 0) - return error.BadHeader; - - const hash = Crc32.hash(std.mem.asBytes(&header.flags)); - if (hash != header.crc32) - return error.WrongChecksum; - - return Self{ - .allocator = allocator, - .block_decoder = try block.decoder(allocator, source, header.flags.check_kind), - .in_reader = source, - }; - } - - pub fn deinit(self: *Self) void { - self.block_decoder.deinit(); - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - pub fn read(self: *Self, buffer: []u8) Error!usize { - if (buffer.len == 0) - return 0; - - const r = try self.block_decoder.read(buffer); - if (r != 0) - return r; - - const index_size = blk: { - var hasher = std.compress.hashedReader(self.in_reader, Crc32.init()); - hasher.hasher.update(&[1]u8{0x00}); - - var counter = std.io.countingReader(hasher.reader()); - counter.bytes_read += 1; - - const counting_reader = counter.reader(); - - const record_count = try multibyte.readInt(counting_reader); - if (record_count != self.block_decoder.block_count) - return error.CorruptInput; - - var i: usize = 0; - while (i < record_count) : (i += 1) { - // TODO: validate records - _ = try multibyte.readInt(counting_reader); - _ = try multibyte.readInt(counting_reader); - } - - while (counter.bytes_read % 4 != 0) { - if (try counting_reader.readByte() != 0) - return error.CorruptInput; - } - - const hash_a = hasher.hasher.final(); - const hash_b = try counting_reader.readIntLittle(u32); - if (hash_a != hash_b) - return error.WrongChecksum; - - break :blk counter.bytes_read; - }; - - const Footer = extern struct { - crc32: u32, - backward_size: u32, - flags: Flags, - magic: [2]u8, - }; - - const footer = try self.in_reader.readStruct(Footer); - const backward_size = (footer.backward_size + 1) * 4; - if (backward_size != index_size) - return error.CorruptInput; - - if (footer.flags.reserved1 != 0 or footer.flags.reserved2 != 0) - return error.CorruptInput; - - var hasher = Crc32.init(); - hasher.update(std.mem.asBytes(&footer.backward_size)); - hasher.update(std.mem.asBytes(&footer.flags)); - const hash = hasher.final(); - if (hash != footer.crc32) - return error.WrongChecksum; - - if (!std.mem.eql(u8, &footer.magic, &.{ 'Y', 'Z' })) - return error.CorruptInput; - - return 0; - } - }; -} diff --git a/lib/std/compress/xz/stream_test.zig b/lib/std/compress/xz/test.zig similarity index 94% rename from lib/std/compress/xz/stream_test.zig rename to lib/std/compress/xz/test.zig index beaeedf535..848f518c78 100644 --- a/lib/std/compress/xz/stream_test.zig +++ b/lib/std/compress/xz/test.zig @@ -1,11 +1,11 @@ -const std = @import("std"); +const std = @import("../../std.zig"); const testing = std.testing; -const stream = @import("stream.zig").stream; +const xz = std.compress.xz; fn decompress(data: []const u8) ![]u8 { var in_stream = std.io.fixedBufferStream(data); - var xz_stream = try stream(testing.allocator, in_stream.reader()); + var xz_stream = try xz.decompress(testing.allocator, in_stream.reader()); defer xz_stream.deinit(); return xz_stream.reader().readAllAlloc(testing.allocator, std.math.maxInt(usize));