diff --git a/lib/std/compress/flate.zig b/lib/std/compress/flate.zig index 6a111ac0fc..a93e03a826 100644 --- a/lib/std/compress/flate.zig +++ b/lib/std/compress/flate.zig @@ -1,3 +1,5 @@ +const std = @import("../std.zig"); + /// Deflate is a lossless data compression file format that uses a combination /// of LZ77 and Huffman coding. pub const deflate = @import("flate/deflate.zig"); @@ -7,77 +9,48 @@ pub const deflate = @import("flate/deflate.zig"); pub const inflate = @import("flate/inflate.zig"); /// Decompress compressed data from reader and write plain data to the writer. -pub fn decompress(reader: anytype, writer: anytype) !void { +pub fn decompress(reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) anyerror!void { try inflate.decompress(.raw, reader, writer); } -/// Decompressor type -pub fn Decompressor(comptime ReaderType: type) type { - return inflate.Decompressor(.raw, ReaderType); -} - -/// Create Decompressor which will read compressed data from reader. -pub fn decompressor(reader: anytype) Decompressor(@TypeOf(reader)) { - return inflate.decompressor(.raw, reader); -} +pub const Decompressor = inflate.Decompressor(.raw); /// Compression level, trades between speed and compression size. pub const Options = deflate.Options; /// Compress plain data from reader and write compressed data to the writer. -pub fn compress(reader: anytype, writer: anytype, options: Options) !void { +pub fn compress(reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter, options: Options) anyerror!void { try deflate.compress(.raw, reader, writer, options); } -/// Compressor type -pub fn Compressor(comptime WriterType: type) type { - return deflate.Compressor(.raw, WriterType); -} - -/// Create Compressor which outputs compressed data to the writer. -pub fn compressor(writer: anytype, options: Options) !Compressor(@TypeOf(writer)) { - return try deflate.compressor(.raw, writer, options); -} +pub const Compressor = deflate.Compressor(.raw); /// Huffman only compression. Without Lempel-Ziv match searching. Faster /// compression, less memory requirements but bigger compressed sizes. pub const huffman = struct { - pub fn compress(reader: anytype, writer: anytype) !void { + pub fn compress(reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) anyerror!void { try deflate.huffman.compress(.raw, reader, writer); } - pub fn Compressor(comptime WriterType: type) type { - return deflate.huffman.Compressor(.raw, WriterType); - } - - pub fn compressor(writer: anytype) !huffman.Compressor(@TypeOf(writer)) { - return deflate.huffman.compressor(.raw, writer); - } + pub const Compressor = deflate.huffman.Compressor(.raw); }; // No compression store only. Compressed size is slightly bigger than plain. pub const store = struct { - pub fn compress(reader: anytype, writer: anytype) !void { + pub fn compress(reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) anyerror!void { try deflate.store.compress(.raw, reader, writer); } - pub fn Compressor(comptime WriterType: type) type { - return deflate.store.Compressor(.raw, WriterType); - } - - pub fn compressor(writer: anytype) !store.Compressor(@TypeOf(writer)) { - return deflate.store.compressor(.raw, writer); - } + pub const Compressor = deflate.store.Compressor(.raw); }; -/// Container defines header/footer around deflate bit stream. Gzip and zlib -/// compression algorithms are containers around deflate bit stream body. -const Container = @import("flate/container.zig").Container; -const std = @import("std"); +const builtin = @import("builtin"); const testing = std.testing; const fixedBufferStream = std.io.fixedBufferStream; const print = std.debug.print; -const builtin = @import("builtin"); +/// Container defines header/footer around deflate bit stream. Gzip and zlib +/// compression algorithms are containers around deflate bit stream body. +const Container = @import("flate/container.zig").Container; test { _ = deflate; diff --git a/lib/std/compress/flate/bit_reader.zig b/lib/std/compress/flate/bit_reader.zig deleted file mode 100644 index a68fe096ca..0000000000 --- a/lib/std/compress/flate/bit_reader.zig +++ /dev/null @@ -1,421 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const testing = std.testing; - -pub const Flags = packed struct(u3) { - /// dont advance internal buffer, just get bits, leave them in buffer - peek: bool = false, - /// assume that there is no need to fill, fill should be called before - buffered: bool = false, - /// bit reverse read bits - reverse: bool = false, -}; - -/// Bit reader used during inflate (decompression). Has internal buffer of 64 -/// bits which shifts right after bits are consumed. Uses forward_reader to fill -/// that internal buffer when needed. -/// -/// readF is the core function. Supports few different ways of getting bits -/// controlled by flags. In hot path we try to avoid checking whether we need to -/// fill buffer from forward_reader by calling fill in advance and readF with -/// buffered flag set. -/// -pub fn BitReader(comptime T: type) type { - assert(T == u32 or T == u64); - const t_bytes: usize = @sizeOf(T); - const Tshift = if (T == u64) u6 else u5; - - return struct { - // Underlying reader used for filling internal bits buffer - forward_reader: *std.io.BufferedReader, - // Internal buffer of 64 bits - bits: T = 0, - // Number of bits in the buffer - nbits: u32 = 0, - - const Self = @This(); - - pub fn init(forward_reader: *std.io.BufferedReader) Self { - var self = Self{ .forward_reader = forward_reader }; - self.fill(1) catch {}; - return self; - } - - /// Try to have `nice` bits are available in buffer. Reads from - /// forward reader if there is no `nice` bits in buffer. Returns error - /// if end of forward stream is reached and internal buffer is empty. - /// It will not error if less than `nice` bits are in buffer, only when - /// all bits are exhausted. During inflate we usually know what is the - /// maximum bits for the next step but usually that step will need less - /// bits to decode. So `nice` is not hard limit, it will just try to have - /// that number of bits available. If end of forward stream is reached - /// it may be some extra zero bits in buffer. - pub fn fill(self: *Self, nice: u6) !void { - if (self.nbits >= nice and nice != 0) { - return; // We have enough bits - } - // Read more bits from forward reader - - // Number of empty bytes in bits, round nbits to whole bytes. - const empty_bytes = - @as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // 8 for 8, 16, 24..., 7 otherwise - (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8 - - var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes; - const bytes_read = self.forward_reader.readAll(buf[0..empty_bytes]) catch 0; - if (bytes_read > 0) { - const u: T = std.mem.readInt(T, buf[0..t_bytes], .little); - self.bits |= u << @as(Tshift, @intCast(self.nbits)); - self.nbits += 8 * @as(u8, @intCast(bytes_read)); - return; - } - - if (self.nbits == 0) - return error.EndOfStream; - } - - /// Read exactly buf.len bytes into buf. - pub fn readAll(self: *Self, buf: []u8) !void { - assert(self.alignBits() == 0); // internal bits must be at byte boundary - - // First read from internal bits buffer. - var n: usize = 0; - while (self.nbits > 0 and n < buf.len) { - buf[n] = try self.readF(u8, .{ .buffered = true }); - n += 1; - } - // Then use forward reader for all other bytes. - try self.forward_reader.readNoEof(buf[n..]); - } - - /// Alias for readF(U, 0). - pub fn read(self: *Self, comptime U: type) !U { - return self.readF(U, 0); - } - - /// Alias for readF with flag.peak set. - pub inline fn peekF(self: *Self, comptime U: type, comptime how: Flags) !U { - return self.readF(U, .{ - .peek = true, - .buffered = how.buffered, - .reverse = how.reverse, - }); - } - - /// Read with flags provided. - pub fn readF(self: *Self, comptime U: type, comptime how: Flags) !U { - if (U == T) { - assert(how == 0); - assert(self.alignBits() == 0); - try self.fill(@bitSizeOf(T)); - if (self.nbits != @bitSizeOf(T)) return error.EndOfStream; - const v = self.bits; - self.nbits = 0; - self.bits = 0; - return v; - } - const n: Tshift = @bitSizeOf(U); - switch (how) { - 0 => { // `normal` read - try self.fill(n); // ensure that there are n bits in the buffer - const u: U = @truncate(self.bits); // get n bits - try self.shift(n); // advance buffer for n - return u; - }, - .{ .peek = true } => { // no shift, leave bits in the buffer - try self.fill(n); - return @truncate(self.bits); - }, - .{ .buffered = true } => { // no fill, assume that buffer has enough bits - const u: U = @truncate(self.bits); - try self.shift(n); - return u; - }, - .{ .reverse = true } => { // same as 0 with bit reverse - try self.fill(n); - const u: U = @truncate(self.bits); - try self.shift(n); - return @bitReverse(u); - }, - .{ .peek = true, .reverse = true } => { - try self.fill(n); - return @bitReverse(@as(U, @truncate(self.bits))); - }, - .{ .buffered = true, .reverse = true } => { - const u: U = @truncate(self.bits); - try self.shift(n); - return @bitReverse(u); - }, - .{ .peek = true, .buffered = true }, - => { - return @truncate(self.bits); - }, - .{ .peek = true, .buffered = true, .reverse = true } => { - return @bitReverse(@as(U, @truncate(self.bits))); - }, - } - } - - /// Read n number of bits. - /// Only buffered flag can be used in how. - pub fn readN(self: *Self, n: u4, comptime how: u3) !u16 { - switch (how) { - 0 => { - try self.fill(n); - }, - .{ .buffered = true } => {}, - else => unreachable, - } - const mask: u16 = (@as(u16, 1) << n) - 1; - const u: u16 = @as(u16, @truncate(self.bits)) & mask; - try self.shift(n); - return u; - } - - /// Advance buffer for n bits. - pub fn shift(self: *Self, n: Tshift) !void { - if (n > self.nbits) return error.EndOfStream; - self.bits >>= n; - self.nbits -= n; - } - - /// Skip n bytes. - pub fn skipBytes(self: *Self, n: u16) !void { - for (0..n) |_| { - try self.fill(8); - try self.shift(8); - } - } - - // Number of bits to align stream to the byte boundary. - fn alignBits(self: *Self) u3 { - return @intCast(self.nbits & 0x7); - } - - /// Align stream to the byte boundary. - pub fn alignToByte(self: *Self) void { - const ab = self.alignBits(); - if (ab > 0) self.shift(ab) catch unreachable; - } - - /// Skip zero terminated string. - pub fn skipStringZ(self: *Self) !void { - while (true) { - if (try self.readF(u8, 0) == 0) break; - } - } - - /// Read deflate fixed fixed code. - /// Reads first 7 bits, and then maybe 1 or 2 more to get full 7,8 or 9 bit code. - /// ref: https://datatracker.ietf.org/doc/html/rfc1951#page-12 - /// Lit Value Bits Codes - /// --------- ---- ----- - /// 0 - 143 8 00110000 through - /// 10111111 - /// 144 - 255 9 110010000 through - /// 111111111 - /// 256 - 279 7 0000000 through - /// 0010111 - /// 280 - 287 8 11000000 through - /// 11000111 - pub fn readFixedCode(self: *Self) !u16 { - try self.fill(7 + 2); - const code7 = try self.readF(u7, .{ .buffered = true, .reverse = true }); - if (code7 <= 0b0010_111) { // 7 bits, 256-279, codes 0000_000 - 0010_111 - return @as(u16, code7) + 256; - } else if (code7 <= 0b1011_111) { // 8 bits, 0-143, codes 0011_0000 through 1011_1111 - return (@as(u16, code7) << 1) + @as(u16, try self.readF(u1, .{ .buffered = true })) - 0b0011_0000; - } else if (code7 <= 0b1100_011) { // 8 bit, 280-287, codes 1100_0000 - 1100_0111 - return (@as(u16, code7 - 0b1100000) << 1) + try self.readF(u1, .{ .buffered = true }) + 280; - } else { // 9 bit, 144-255, codes 1_1001_0000 - 1_1111_1111 - return (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, try self.readF(u2, .{ .buffered = true, .reverse = true })) + 144; - } - } - }; -} - -test "readF" { - var input: std.io.BufferedReader = undefined; - input.initFixed(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 }); - var br: BitReader(u64) = .init(&input); - - try testing.expectEqual(@as(u8, 48), br.nbits); - try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits); - - try testing.expect(try br.readF(u1, 0) == 0b0000_0001); - try testing.expect(try br.readF(u2, 0) == 0b0000_0001); - try testing.expectEqual(@as(u8, 48 - 3), br.nbits); - try testing.expectEqual(@as(u3, 5), br.alignBits()); - - try testing.expect(try br.readF(u8, .{ .peek = true }) == 0b0001_1110); - try testing.expect(try br.readF(u9, .{ .peek = true }) == 0b1_0001_1110); - try br.shift(9); - try testing.expectEqual(@as(u8, 36), br.nbits); - try testing.expectEqual(@as(u3, 4), br.alignBits()); - - try testing.expect(try br.readF(u4, 0) == 0b0100); - try testing.expectEqual(@as(u8, 32), br.nbits); - try testing.expectEqual(@as(u3, 0), br.alignBits()); - - try br.shift(1); - try testing.expectEqual(@as(u3, 7), br.alignBits()); - try br.shift(1); - try testing.expectEqual(@as(u3, 6), br.alignBits()); - br.alignToByte(); - try testing.expectEqual(@as(u3, 0), br.alignBits()); - - try testing.expectEqual(@as(u64, 0xc9), br.bits); - try testing.expectEqual(@as(u16, 0x9), try br.readN(4, 0)); - try testing.expectEqual(@as(u16, 0xc), try br.readN(4, 0)); -} - -test "read block type 1 data" { - inline for ([_]type{ u64, u32 }) |T| { - const data = [_]u8{ - 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1 - 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00, - 0x0c, 0x01, 0x02, 0x03, // - 0xaa, 0xbb, 0xcc, 0xdd, - }; - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&data); - var br: BitReader(T) = .init(&fbs); - - try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal - try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type - - for ("Hello world\n") |c| { - try testing.expectEqual(@as(u8, c), try br.readF(u8, .{ .reverse = true }) - 0x30); - } - try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block - br.alignToByte(); - try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0)); - try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0)); - try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0)); - } -} - -test "shift/fill" { - const data = [_]u8{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - }; - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&data); - var br: BitReader(u64) = .init(&fbs); - - try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits); - try br.shift(8); - try testing.expectEqual(@as(u64, 0x00_08_07_06_05_04_03_02), br.bits); - try br.fill(60); // fill with 1 byte - try testing.expectEqual(@as(u64, 0x01_08_07_06_05_04_03_02), br.bits); - try br.shift(8 * 4 + 4); - try testing.expectEqual(@as(u64, 0x00_00_00_00_00_10_80_70), br.bits); - - try br.fill(60); // fill with 4 bytes (shift by 4) - try testing.expectEqual(@as(u64, 0x00_50_40_30_20_10_80_70), br.bits); - try testing.expectEqual(@as(u8, 8 * 7 + 4), br.nbits); - - try br.shift(@intCast(br.nbits)); // clear buffer - try br.fill(8); // refill with the rest of the bytes - try testing.expectEqual(@as(u64, 0x00_00_00_00_00_08_07_06), br.bits); -} - -test "readAll" { - inline for ([_]type{ u64, u32 }) |T| { - const data = [_]u8{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - }; - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&data); - var br: BitReader(T) = .init(&fbs); - - switch (T) { - u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits), - u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits), - else => unreachable, - } - - var out: [16]u8 = undefined; - try br.readAll(out[0..]); - try testing.expect(br.nbits == 0); - try testing.expect(br.bits == 0); - - try testing.expectEqualSlices(u8, data[0..16], &out); - } -} - -test "readFixedCode" { - inline for ([_]type{ u64, u32 }) |T| { - const fixed_codes = @import("huffman_encoder.zig").fixed_codes; - - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&fixed_codes); - var rdr: BitReader(T) = .init(&fbs); - - for (0..286) |c| { - try testing.expectEqual(c, try rdr.readFixedCode()); - } - try testing.expect(rdr.nbits == 0); - } -} - -test "u32 leaves no bits on u32 reads" { - const data = [_]u8{ - 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - }; - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&data); - var br: BitReader(u32) = .init(&fbs); - - _ = try br.read(u3); - try testing.expectEqual(29, br.nbits); - br.alignToByte(); - try testing.expectEqual(24, br.nbits); - try testing.expectEqual(0x04_03_02_01, try br.read(u32)); - try testing.expectEqual(0, br.nbits); - try testing.expectEqual(0x08_07_06_05, try br.read(u32)); - try testing.expectEqual(0, br.nbits); - - _ = try br.read(u9); - try testing.expectEqual(23, br.nbits); - br.alignToByte(); - try testing.expectEqual(16, br.nbits); - try testing.expectEqual(0x0e_0d_0c_0b, try br.read(u32)); - try testing.expectEqual(0, br.nbits); -} - -test "u64 need fill after alignToByte" { - const data = [_]u8{ - 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - }; - - // without fill - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(&data); - var br: BitReader(u64) = .init(&fbs); - _ = try br.read(u23); - try testing.expectEqual(41, br.nbits); - br.alignToByte(); - try testing.expectEqual(40, br.nbits); - try testing.expectEqual(0x06_05_04_03, try br.read(u32)); - try testing.expectEqual(8, br.nbits); - try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); - try testing.expectEqual(32, br.nbits); - - // fill after align ensures all bits filled - fbs.reset(); - br = .init(&fbs); - _ = try br.read(u23); - try testing.expectEqual(41, br.nbits); - br.alignToByte(); - try br.fill(0); - try testing.expectEqual(64, br.nbits); - try testing.expectEqual(0x06_05_04_03, try br.read(u32)); - try testing.expectEqual(32, br.nbits); - try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); - try testing.expectEqual(0, br.nbits); -} diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index fcc73c9878..de0eecc7e3 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -3,7 +3,6 @@ const assert = std.debug.assert; const testing = std.testing; const hfd = @import("huffman_decoder.zig"); -const BitReader = @import("bit_reader.zig").BitReader; const CircularBuffer = @import("CircularBuffer.zig"); const Container = @import("container.zig").Container; const Token = @import("Token.zig"); @@ -48,16 +47,14 @@ pub fn Decompressor(comptime container: Container) type { /// * 64K for history (CircularBuffer) /// * ~10K huffman decoders (Literal and DistanceDecoder) /// -pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type { - assert(LookaheadType == u32 or LookaheadType == u64); - const BitReaderType = BitReader(LookaheadType); +pub fn Inflate(comptime container: Container, comptime Lookahead: type) type { + assert(Lookahead == u32 or Lookahead == u64); + const LookaheadBitReader = BitReader(Lookahead); return struct { - const F = BitReaderType.flag; - - bits: BitReaderType, + bits: LookaheadBitReader, hist: CircularBuffer = .{}, - // Hashes, produces checkusm, of uncompressed data for gzip/zlib footer. + // Hashes, produces checksum, of uncompressed data for gzip/zlib footer. hasher: container.Hasher() = .{}, // dynamic block huffman code decoders @@ -79,7 +76,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type const Self = @This(); - pub const Error = BitReaderType.Error || Container.Error || hfd.Error || error{ + pub const Error = anyerror || Container.Error || hfd.Error || error{ InvalidCode, InvalidMatch, InvalidBlockType, @@ -88,10 +85,10 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type }; pub fn init(bw: *std.io.BufferedReader) Self { - return .{ .bits = BitReaderType.init(bw) }; + return .{ .bits = LookaheadBitReader.init(bw) }; } - fn blockHeader(self: *Self) !void { + fn blockHeader(self: *Self) anyerror!void { self.bfinal = try self.bits.read(u1); self.block_type = try self.bits.read(u2); } @@ -129,7 +126,10 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type fn fixedDistanceCode(self: *Self, code: u8) !void { try self.bits.fill(5 + 5 + 13); const length = try self.decodeLength(code); - const distance = try self.decodeDistance(try self.bits.readF(u5, F.buffered | F.reverse)); + const distance = try self.decodeDistance(try self.bits.readF(u5, .{ + .buffered = true, + .reverse = true, + })); try self.hist.writeMatch(length, distance); } @@ -139,7 +139,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type return if (ml.extra_bits == 0) // 0 - 5 extra bits ml.base else - ml.base + try self.bits.readN(ml.extra_bits, F.buffered); + ml.base + try self.bits.readN(ml.extra_bits, .{ .buffered = true }); } fn decodeDistance(self: *Self, code: u8) !u16 { @@ -148,7 +148,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type return if (md.extra_bits == 0) // 0 - 13 extra bits md.base else - md.base + try self.bits.readN(md.extra_bits, F.buffered); + md.base + try self.bits.readN(md.extra_bits, .{ .buffered = true }); } fn dynamicBlockHeader(self: *Self) !void { @@ -171,7 +171,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type var dec_lens = [_]u4{0} ** (286 + 30); var pos: usize = 0; while (pos < hlit + hdist) { - const sym = try cl_dec.find(try self.bits.peekF(u7, F.reverse)); + const sym = try cl_dec.find(try self.bits.peekF(u7, .{ .reverse = true })); try self.bits.shift(sym.code_bits); pos += try self.dynamicCodeLength(sym.symbol, &dec_lens, pos); } @@ -230,13 +230,13 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type .literal => self.hist.write(sym.symbol), .match => { // Decode match backreference // fill so we can use buffered reads - if (LookaheadType == u32) + if (Lookahead == u32) try self.bits.fill(5 + 15) else try self.bits.fill(5 + 15 + 13); const length = try self.decodeLength(sym.symbol); const dsm = try self.decodeSymbol(&self.dst_dec); - if (LookaheadType == u32) try self.bits.fill(13); + if (Lookahead == u32) try self.bits.fill(13); const distance = try self.decodeDistance(dsm.symbol); try self.hist.writeMatch(length, distance); }, @@ -251,7 +251,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type // used. Shift bit reader for that much bits, those bits are used. And // return symbol. fn decodeSymbol(self: *Self, decoder: anytype) !hfd.Symbol { - const sym = try decoder.find(try self.bits.peekF(u15, F.buffered | F.reverse)); + const sym = try decoder.find(try self.bits.peekF(u15, .{ .buffered = true, .reverse = true })); try self.bits.shift(sym.code_bits); return sym; } @@ -338,22 +338,48 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type) type } } - // Reader interface - - pub const Reader = std.io.Reader(*Self, Error, read); - - /// Returns the number of bytes read. It may be less than buffer.len. - /// If the number of bytes read is 0, it means end of stream. - /// End of stream is not an error condition. - pub fn read(self: *Self, buffer: []u8) Error!usize { - if (buffer.len == 0) return 0; - const out = try self.get(buffer.len); - @memcpy(buffer[0..out.len], out); - return out.len; + fn reader_streamRead( + ctx: ?*anyopaque, + bw: *std.io.BufferedWriter, + limit: std.io.Reader.Limit, + ) std.io.Reader.RwResult { + const self: *Self = @alignCast(@ptrCast(ctx)); + const out = bw.writableSlice(1) catch |err| return .{ .write_err = err }; + const in = self.get(limit.min(out.len)) catch |err| return .{ .read_err = err }; + if (in.len == 0) return .{ .read_end = true }; + @memcpy(out[0..in.len], in); + return .{ .len = in.len }; } - pub fn reader(self: *Self) Reader { - return .{ .context = self }; + fn reader_streamReadVec(ctx: ?*anyopaque, data: []const []u8) std.io.Reader.Result { + const self: *Self = @alignCast(@ptrCast(ctx)); + var total: usize = 0; + for (data) |buffer| { + if (buffer.len == 0) break; + const out = self.get(buffer.len) catch |err| { + return .{ .len = total, .err = err }; + }; + if (out.len == 0) break; + @memcpy(buffer[0..out.len], out); + total += out.len; + } + return .{ .len = total, .end = total == 0 }; + } + + pub fn streamReadVec(self: *Self, data: []const []u8) std.io.Reader.Result { + return reader_streamReadVec(self, data); + } + + pub fn reader(self: *Self) std.io.Reader { + return .{ + .context = self, + .vtable = &.{ + .posRead = null, + .posReadVec = null, + .streamRead = reader_streamRead, + .streamReadVec = reader_streamReadVec, + }, + }; } }; } @@ -567,3 +593,427 @@ test "bug 19895" { var buf: [0]u8 = undefined; try testing.expectEqual(0, try decomp.read(&buf)); } + +/// Bit reader used during inflate (decompression). Has internal buffer of 64 +/// bits which shifts right after bits are consumed. Uses forward_reader to fill +/// that internal buffer when needed. +/// +/// readF is the core function. Supports few different ways of getting bits +/// controlled by flags. In hot path we try to avoid checking whether we need to +/// fill buffer from forward_reader by calling fill in advance and readF with +/// buffered flag set. +/// +pub fn BitReader(comptime T: type) type { + assert(T == u32 or T == u64); + const t_bytes: usize = @sizeOf(T); + const Tshift = if (T == u64) u6 else u5; + + return struct { + // Underlying reader used for filling internal bits buffer + forward_reader: *std.io.BufferedReader, + // Internal buffer of 64 bits + bits: T = 0, + // Number of bits in the buffer + nbits: u32 = 0, + + const Self = @This(); + + pub const Flags = packed struct(u3) { + /// dont advance internal buffer, just get bits, leave them in buffer + peek: bool = false, + /// assume that there is no need to fill, fill should be called before + buffered: bool = false, + /// bit reverse read bits + reverse: bool = false, + + /// work around https://github.com/ziglang/zig/issues/18882 + pub inline fn toInt(f: Flags) u3 { + return @bitCast(f); + } + }; + + pub fn init(forward_reader: *std.io.BufferedReader) Self { + var self = Self{ .forward_reader = forward_reader }; + self.fill(1) catch {}; + return self; + } + + /// Try to have `nice` bits are available in buffer. Reads from + /// forward reader if there is no `nice` bits in buffer. Returns error + /// if end of forward stream is reached and internal buffer is empty. + /// It will not error if less than `nice` bits are in buffer, only when + /// all bits are exhausted. During inflate we usually know what is the + /// maximum bits for the next step but usually that step will need less + /// bits to decode. So `nice` is not hard limit, it will just try to have + /// that number of bits available. If end of forward stream is reached + /// it may be some extra zero bits in buffer. + pub fn fill(self: *Self, nice: u6) !void { + if (self.nbits >= nice and nice != 0) { + return; // We have enough bits + } + // Read more bits from forward reader + + // Number of empty bytes in bits, round nbits to whole bytes. + const empty_bytes = + @as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // 8 for 8, 16, 24..., 7 otherwise + (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8 + + var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes; + const bytes_read = self.forward_reader.partialRead(buf[0..empty_bytes]) catch 0; + if (bytes_read > 0) { + const u: T = std.mem.readInt(T, buf[0..t_bytes], .little); + self.bits |= u << @as(Tshift, @intCast(self.nbits)); + self.nbits += 8 * @as(u8, @intCast(bytes_read)); + return; + } + + if (self.nbits == 0) + return error.EndOfStream; + } + + /// Read exactly buf.len bytes into buf. + pub fn readAll(self: *Self, buf: []u8) anyerror!void { + assert(self.alignBits() == 0); // internal bits must be at byte boundary + + // First read from internal bits buffer. + var n: usize = 0; + while (self.nbits > 0 and n < buf.len) { + buf[n] = try self.readF(u8, .{ .buffered = true }); + n += 1; + } + // Then use forward reader for all other bytes. + try self.forward_reader.read(buf[n..]); + } + + /// Alias for readF(U, 0). + pub fn read(self: *Self, comptime U: type) !U { + return self.readF(U, .{}); + } + + /// Alias for readF with flag.peak set. + pub inline fn peekF(self: *Self, comptime U: type, comptime how: Flags) !U { + return self.readF(U, .{ + .peek = true, + .buffered = how.buffered, + .reverse = how.reverse, + }); + } + + /// Read with flags provided. + pub fn readF(self: *Self, comptime U: type, comptime how: Flags) !U { + if (U == T) { + assert(how.toInt() == 0); + assert(self.alignBits() == 0); + try self.fill(@bitSizeOf(T)); + if (self.nbits != @bitSizeOf(T)) return error.EndOfStream; + const v = self.bits; + self.nbits = 0; + self.bits = 0; + return v; + } + const n: Tshift = @bitSizeOf(U); + // work around https://github.com/ziglang/zig/issues/18882 + switch (how.toInt()) { + @as(Flags, .{}).toInt() => { // `normal` read + try self.fill(n); // ensure that there are n bits in the buffer + const u: U = @truncate(self.bits); // get n bits + try self.shift(n); // advance buffer for n + return u; + }, + @as(Flags, .{ .peek = true }).toInt() => { // no shift, leave bits in the buffer + try self.fill(n); + return @truncate(self.bits); + }, + @as(Flags, .{ .buffered = true }).toInt() => { // no fill, assume that buffer has enough bits + const u: U = @truncate(self.bits); + try self.shift(n); + return u; + }, + @as(Flags, .{ .reverse = true }).toInt() => { // same as 0 with bit reverse + try self.fill(n); + const u: U = @truncate(self.bits); + try self.shift(n); + return @bitReverse(u); + }, + @as(Flags, .{ .peek = true, .reverse = true }).toInt() => { + try self.fill(n); + return @bitReverse(@as(U, @truncate(self.bits))); + }, + @as(Flags, .{ .buffered = true, .reverse = true }).toInt() => { + const u: U = @truncate(self.bits); + try self.shift(n); + return @bitReverse(u); + }, + @as(Flags, .{ .peek = true, .buffered = true }).toInt() => { + return @truncate(self.bits); + }, + @as(Flags, .{ .peek = true, .buffered = true, .reverse = true }).toInt() => { + return @bitReverse(@as(U, @truncate(self.bits))); + }, + } + } + + /// Read n number of bits. + /// Only buffered flag can be used in how. + pub fn readN(self: *Self, n: u4, comptime how: Flags) !u16 { + // work around https://github.com/ziglang/zig/issues/18882 + switch (how.toInt()) { + @as(Flags, .{}).toInt() => { + try self.fill(n); + }, + @as(Flags, .{ .buffered = true }).toInt() => {}, + else => unreachable, + } + const mask: u16 = (@as(u16, 1) << n) - 1; + const u: u16 = @as(u16, @truncate(self.bits)) & mask; + try self.shift(n); + return u; + } + + /// Advance buffer for n bits. + pub fn shift(self: *Self, n: Tshift) !void { + if (n > self.nbits) return error.EndOfStream; + self.bits >>= n; + self.nbits -= n; + } + + /// Skip n bytes. + pub fn skipBytes(self: *Self, n: u16) !void { + for (0..n) |_| { + try self.fill(8); + try self.shift(8); + } + } + + // Number of bits to align stream to the byte boundary. + fn alignBits(self: *Self) u3 { + return @intCast(self.nbits & 0x7); + } + + /// Align stream to the byte boundary. + pub fn alignToByte(self: *Self) void { + const ab = self.alignBits(); + if (ab > 0) self.shift(ab) catch unreachable; + } + + /// Skip zero terminated string. + pub fn skipStringZ(self: *Self) !void { + while (true) { + if (try self.readF(u8, 0) == 0) break; + } + } + + /// Read deflate fixed fixed code. + /// Reads first 7 bits, and then maybe 1 or 2 more to get full 7,8 or 9 bit code. + /// ref: https://datatracker.ietf.org/doc/html/rfc1951#page-12 + /// Lit Value Bits Codes + /// --------- ---- ----- + /// 0 - 143 8 00110000 through + /// 10111111 + /// 144 - 255 9 110010000 through + /// 111111111 + /// 256 - 279 7 0000000 through + /// 0010111 + /// 280 - 287 8 11000000 through + /// 11000111 + pub fn readFixedCode(self: *Self) !u16 { + try self.fill(7 + 2); + const code7 = try self.readF(u7, .{ .buffered = true, .reverse = true }); + if (code7 <= 0b0010_111) { // 7 bits, 256-279, codes 0000_000 - 0010_111 + return @as(u16, code7) + 256; + } else if (code7 <= 0b1011_111) { // 8 bits, 0-143, codes 0011_0000 through 1011_1111 + return (@as(u16, code7) << 1) + @as(u16, try self.readF(u1, .{ .buffered = true })) - 0b0011_0000; + } else if (code7 <= 0b1100_011) { // 8 bit, 280-287, codes 1100_0000 - 1100_0111 + return (@as(u16, code7 - 0b1100000) << 1) + try self.readF(u1, .{ .buffered = true }) + 280; + } else { // 9 bit, 144-255, codes 1_1001_0000 - 1_1111_1111 + return (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, try self.readF(u2, .{ .buffered = true, .reverse = true })) + 144; + } + } + }; +} + +test "readF" { + var input: std.io.BufferedReader = undefined; + input.initFixed(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 }); + var br: BitReader(u64) = .init(&input); + + try testing.expectEqual(@as(u8, 48), br.nbits); + try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits); + + try testing.expect(try br.readF(u1, 0) == 0b0000_0001); + try testing.expect(try br.readF(u2, 0) == 0b0000_0001); + try testing.expectEqual(@as(u8, 48 - 3), br.nbits); + try testing.expectEqual(@as(u3, 5), br.alignBits()); + + try testing.expect(try br.readF(u8, .{ .peek = true }) == 0b0001_1110); + try testing.expect(try br.readF(u9, .{ .peek = true }) == 0b1_0001_1110); + try br.shift(9); + try testing.expectEqual(@as(u8, 36), br.nbits); + try testing.expectEqual(@as(u3, 4), br.alignBits()); + + try testing.expect(try br.readF(u4, 0) == 0b0100); + try testing.expectEqual(@as(u8, 32), br.nbits); + try testing.expectEqual(@as(u3, 0), br.alignBits()); + + try br.shift(1); + try testing.expectEqual(@as(u3, 7), br.alignBits()); + try br.shift(1); + try testing.expectEqual(@as(u3, 6), br.alignBits()); + br.alignToByte(); + try testing.expectEqual(@as(u3, 0), br.alignBits()); + + try testing.expectEqual(@as(u64, 0xc9), br.bits); + try testing.expectEqual(@as(u16, 0x9), try br.readN(4, 0)); + try testing.expectEqual(@as(u16, 0xc), try br.readN(4, 0)); +} + +test "read block type 1 data" { + inline for ([_]type{ u64, u32 }) |T| { + const data = [_]u8{ + 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1 + 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00, + 0x0c, 0x01, 0x02, 0x03, // + 0xaa, 0xbb, 0xcc, 0xdd, + }; + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&data); + var br: BitReader(T) = .init(&fbs); + + try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal + try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type + + for ("Hello world\n") |c| { + try testing.expectEqual(@as(u8, c), try br.readF(u8, .{ .reverse = true }) - 0x30); + } + try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block + br.alignToByte(); + try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0)); + try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0)); + try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0)); + } +} + +test "shift/fill" { + const data = [_]u8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + }; + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&data); + var br: BitReader(u64) = .init(&fbs); + + try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits); + try br.shift(8); + try testing.expectEqual(@as(u64, 0x00_08_07_06_05_04_03_02), br.bits); + try br.fill(60); // fill with 1 byte + try testing.expectEqual(@as(u64, 0x01_08_07_06_05_04_03_02), br.bits); + try br.shift(8 * 4 + 4); + try testing.expectEqual(@as(u64, 0x00_00_00_00_00_10_80_70), br.bits); + + try br.fill(60); // fill with 4 bytes (shift by 4) + try testing.expectEqual(@as(u64, 0x00_50_40_30_20_10_80_70), br.bits); + try testing.expectEqual(@as(u8, 8 * 7 + 4), br.nbits); + + try br.shift(@intCast(br.nbits)); // clear buffer + try br.fill(8); // refill with the rest of the bytes + try testing.expectEqual(@as(u64, 0x00_00_00_00_00_08_07_06), br.bits); +} + +test "readAll" { + inline for ([_]type{ u64, u32 }) |T| { + const data = [_]u8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + }; + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&data); + var br: BitReader(T) = .init(&fbs); + + switch (T) { + u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits), + u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits), + else => unreachable, + } + + var out: [16]u8 = undefined; + try br.readAll(out[0..]); + try testing.expect(br.nbits == 0); + try testing.expect(br.bits == 0); + + try testing.expectEqualSlices(u8, data[0..16], &out); + } +} + +test "readFixedCode" { + inline for ([_]type{ u64, u32 }) |T| { + const fixed_codes = @import("huffman_encoder.zig").fixed_codes; + + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&fixed_codes); + var rdr: BitReader(T) = .init(&fbs); + + for (0..286) |c| { + try testing.expectEqual(c, try rdr.readFixedCode()); + } + try testing.expect(rdr.nbits == 0); + } +} + +test "u32 leaves no bits on u32 reads" { + const data = [_]u8{ + 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&data); + var br: BitReader(u32) = .init(&fbs); + + _ = try br.read(u3); + try testing.expectEqual(29, br.nbits); + br.alignToByte(); + try testing.expectEqual(24, br.nbits); + try testing.expectEqual(0x04_03_02_01, try br.read(u32)); + try testing.expectEqual(0, br.nbits); + try testing.expectEqual(0x08_07_06_05, try br.read(u32)); + try testing.expectEqual(0, br.nbits); + + _ = try br.read(u9); + try testing.expectEqual(23, br.nbits); + br.alignToByte(); + try testing.expectEqual(16, br.nbits); + try testing.expectEqual(0x0e_0d_0c_0b, try br.read(u32)); + try testing.expectEqual(0, br.nbits); +} + +test "u64 need fill after alignToByte" { + const data = [_]u8{ + 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + + // without fill + var fbs: std.io.BufferedReader = undefined; + fbs.initFixed(&data); + var br: BitReader(u64) = .init(&fbs); + _ = try br.read(u23); + try testing.expectEqual(41, br.nbits); + br.alignToByte(); + try testing.expectEqual(40, br.nbits); + try testing.expectEqual(0x06_05_04_03, try br.read(u32)); + try testing.expectEqual(8, br.nbits); + try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); + try testing.expectEqual(32, br.nbits); + + // fill after align ensures all bits filled + fbs.reset(); + br = .init(&fbs); + _ = try br.read(u23); + try testing.expectEqual(41, br.nbits); + br.alignToByte(); + try br.fill(0); + try testing.expectEqual(64, br.nbits); + try testing.expectEqual(0x06_05_04_03, try br.read(u32)); + try testing.expectEqual(32, br.nbits); + try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); + try testing.expectEqual(0, br.nbits); +} diff --git a/lib/std/compress/lzma.zig b/lib/std/compress/lzma.zig index aa35c3ffa9..51425d3b12 100644 --- a/lib/std/compress/lzma.zig +++ b/lib/std/compress/lzma.zig @@ -1,90 +1,936 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const math = std.math; const mem = std.mem; const Allocator = std.mem.Allocator; +const testing = std.testing; +const expectEqualSlices = std.testing.expectEqualSlices; +const expectError = std.testing.expectError; -pub const decode = @import("lzma/decode.zig"); +pub const RangeDecoder = struct { + range: u32, + code: u32, -pub fn decompress( - allocator: Allocator, - reader: anytype, -) !Decompress(@TypeOf(reader)) { - return decompressWithOptions(allocator, reader, .{}); -} + pub fn init(rd: *RangeDecoder, br: *std.io.BufferedReader) anyerror!usize { + const reserved = try br.takeByte(); + if (reserved != 0) return error.CorruptInput; + rd.* = .{ + .range = 0xFFFF_FFFF, + .code = try br.takeInt(u32, .big), + }; + return 5; + } -pub fn decompressWithOptions( - allocator: Allocator, - reader: anytype, - options: decode.Options, -) !Decompress(@TypeOf(reader)) { - const params = try decode.Params.readHeader(reader, options); - return Decompress(@TypeOf(reader)).init(allocator, reader, params, options.memlimit); -} + pub inline fn isFinished(self: RangeDecoder) bool { + return self.code == 0; + } -pub fn Decompress(comptime ReaderType: type) type { + inline fn normalize(self: *RangeDecoder, br: *std.io.BufferedReader) !void { + if (self.range < 0x0100_0000) { + self.range <<= 8; + self.code = (self.code << 8) ^ @as(u32, try br.takeByte()); + } + } + + inline fn getBit(self: *RangeDecoder, br: *std.io.BufferedReader) !bool { + self.range >>= 1; + + const bit = self.code >= self.range; + if (bit) + self.code -= self.range; + + try self.normalize(br); + return bit; + } + + pub fn get(self: *RangeDecoder, br: *std.io.BufferedReader, count: usize) !u32 { + var result: u32 = 0; + var i: usize = 0; + while (i < count) : (i += 1) + result = (result << 1) ^ @intFromBool(try self.getBit(br)); + return result; + } + + pub inline fn decodeBit(self: *RangeDecoder, br: *std.io.BufferedReader, prob: *u16, update: bool) !bool { + const bound = (self.range >> 11) * prob.*; + + if (self.code < bound) { + if (update) + prob.* += (0x800 - prob.*) >> 5; + self.range = bound; + + try self.normalize(br); + return false; + } else { + if (update) + prob.* -= prob.* >> 5; + self.code -= bound; + self.range -= bound; + + try self.normalize(br); + return true; + } + } + + fn parseBitTree( + self: *RangeDecoder, + br: *std.io.BufferedReader, + num_bits: u5, + probs: []u16, + update: bool, + ) !u32 { + var tmp: u32 = 1; + var i: @TypeOf(num_bits) = 0; + while (i < num_bits) : (i += 1) { + const bit = try self.decodeBit(br, &probs[tmp], update); + tmp = (tmp << 1) ^ @intFromBool(bit); + } + return tmp - (@as(u32, 1) << num_bits); + } + + pub fn parseReverseBitTree( + self: *RangeDecoder, + br: *std.io.BufferedReader, + num_bits: u5, + probs: []u16, + offset: usize, + update: bool, + ) !u32 { + var result: u32 = 0; + var tmp: usize = 1; + var i: @TypeOf(num_bits) = 0; + while (i < num_bits) : (i += 1) { + const bit = @intFromBool(try self.decodeBit(br, &probs[offset + tmp], update)); + tmp = (tmp << 1) ^ bit; + result ^= @as(u32, bit) << i; + } + return result; + } +}; + +pub const LenDecoder = struct { + choice: u16 = 0x400, + choice2: u16 = 0x400, + low_coder: [16]BitTree(3) = @splat(.{}), + mid_coder: [16]BitTree(3) = @splat(.{}), + high_coder: BitTree(8) = .{}, + + pub fn decode( + self: *LenDecoder, + br: *std.io.BufferedReader, + decoder: *RangeDecoder, + pos_state: usize, + update: bool, + ) !usize { + if (!try decoder.decodeBit(br, &self.choice, update)) { + return @as(usize, try self.low_coder[pos_state].parse(br, decoder, update)); + } else if (!try decoder.decodeBit(br, &self.choice2, update)) { + return @as(usize, try self.mid_coder[pos_state].parse(br, decoder, update)) + 8; + } else { + return @as(usize, try self.high_coder.parse(br, decoder, update)) + 16; + } + } + + pub fn reset(self: *LenDecoder) void { + self.choice = 0x400; + self.choice2 = 0x400; + for (&self.low_coder) |*t| t.reset(); + for (&self.mid_coder) |*t| t.reset(); + self.high_coder.reset(); + } +}; + +pub fn BitTree(comptime num_bits: usize) type { return struct { + probs: [1 << num_bits]u16 = @splat(0x400), + const Self = @This(); - pub const Error = - ReaderType.Error || - Allocator.Error || - error{ CorruptInput, EndOfStream, Overflow }; - - pub const Reader = std.io.Reader(*Self, Error, read); - - allocator: Allocator, - in_reader: ReaderType, - to_read: std.ArrayListUnmanaged(u8), - - buffer: decode.lzbuffer.LzCircularBuffer, - decoder: decode.rangecoder.RangeDecoder, - state: decode.DecoderState, - - pub fn init(allocator: Allocator, source: ReaderType, params: decode.Params, memlimit: ?usize) !Self { - return Self{ - .allocator = allocator, - .in_reader = source, - .to_read = .{}, - - .buffer = decode.lzbuffer.LzCircularBuffer.init(params.dict_size, memlimit orelse math.maxInt(usize)), - .decoder = try decode.rangecoder.RangeDecoder.init(source), - .state = try decode.DecoderState.init(allocator, params.properties, params.unpacked_size), - }; + pub fn parse( + self: *Self, + br: *std.io.BufferedReader, + decoder: *RangeDecoder, + update: bool, + ) !u32 { + return decoder.parseBitTree(br, num_bits, &self.probs, update); } - pub fn reader(self: *Self) Reader { - return .{ .context = self }; + pub fn parseReverse( + self: *Self, + br: *std.io.BufferedReader, + decoder: *RangeDecoder, + update: bool, + ) !u32 { + return decoder.parseReverseBitTree(br, num_bits, &self.probs, 0, update); } - pub fn deinit(self: *Self) void { - self.to_read.deinit(self.allocator); - self.buffer.deinit(self.allocator); - self.state.deinit(self.allocator); - self.* = undefined; - } - - pub fn read(self: *Self, output: []u8) Error!usize { - const writer = self.to_read.writer(self.allocator); - while (self.to_read.items.len < output.len) { - switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { - .continue_ => {}, - .finished => { - try self.buffer.finish(writer); - break; - }, - } - } - const input = self.to_read.items; - const n = @min(input.len, output.len); - @memcpy(output[0..n], input[0..n]); - std.mem.copyForwards(u8, input[0 .. input.len - n], input[n..]); - self.to_read.shrinkRetainingCapacity(input.len - n); - return n; + pub fn reset(self: *Self) void { + @memset(&self.probs, 0x400); } }; } -test { - _ = @import("lzma/test.zig"); - _ = @import("lzma/vec2d.zig"); +pub const Decode = struct { + properties: Properties, + unpacked_size: ?u64, + literal_probs: Vec2D(u16), + pos_slot_decoder: [4]BitTree(6), + align_decoder: BitTree(4), + pos_decoders: [115]u16, + is_match: [192]u16, + is_rep: [12]u16, + is_rep_g0: [12]u16, + is_rep_g1: [12]u16, + is_rep_g2: [12]u16, + is_rep_0long: [192]u16, + state: usize, + rep: [4]usize, + len_decoder: LenDecoder, + rep_len_decoder: LenDecoder, + + pub const Options = struct { + unpacked_size: UnpackedSize = .read_from_header, + memlimit: ?usize = null, + allow_incomplete: bool = false, + }; + + pub const UnpackedSize = union(enum) { + read_from_header, + read_header_but_use_provided: ?u64, + use_provided: ?u64, + }; + + const ProcessingStatus = enum { + cont, + finished, + }; + + pub const Properties = struct { + lc: u4, + lp: u3, + pb: u3, + + fn validate(self: Properties) void { + assert(self.lc <= 8); + assert(self.lp <= 4); + assert(self.pb <= 4); + } + }; + + pub const Params = struct { + properties: Properties, + dict_size: u32, + unpacked_size: ?u64, + + pub fn readHeader(br: *std.io.BufferedReader, options: Options) anyerror!Params { + var props = try br.readByte(); + if (props >= 225) { + return error.CorruptInput; + } + + const lc = @as(u4, @intCast(props % 9)); + props /= 9; + const lp = @as(u3, @intCast(props % 5)); + props /= 5; + const pb = @as(u3, @intCast(props)); + + const dict_size_provided = try br.readInt(u32, .little); + const dict_size = @max(0x1000, dict_size_provided); + + const unpacked_size = switch (options.unpacked_size) { + .read_from_header => blk: { + const unpacked_size_provided = try br.readInt(u64, .little); + const marker_mandatory = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF; + break :blk if (marker_mandatory) + null + else + unpacked_size_provided; + }, + .read_header_but_use_provided => |x| blk: { + _ = try br.readInt(u64, .little); + break :blk x; + }, + .use_provided => |x| x, + }; + + return Params{ + .properties = Properties{ .lc = lc, .lp = lp, .pb = pb }, + .dict_size = dict_size, + .unpacked_size = unpacked_size, + }; + } + }; + + pub fn init( + allocator: Allocator, + properties: Properties, + unpacked_size: ?u64, + ) !Decode { + return .{ + .properties = properties, + .unpacked_size = unpacked_size, + .literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (properties.lc + properties.lp), 0x300 }), + .pos_slot_decoder = @splat(.{}), + .align_decoder = .{}, + .pos_decoders = @splat(0x400), + .is_match = @splat(0x400), + .is_rep = @splat(0x400), + .is_rep_g0 = @splat(0x400), + .is_rep_g1 = @splat(0x400), + .is_rep_g2 = @splat(0x400), + .is_rep_0long = @splat(0x400), + .state = 0, + .rep = @splat(0), + .len_decoder = .{}, + .rep_len_decoder = .{}, + }; + } + + pub fn deinit(self: *Decode, allocator: Allocator) void { + self.literal_probs.deinit(allocator); + self.* = undefined; + } + + pub fn resetState(self: *Decode, allocator: Allocator, new_props: Properties) !void { + new_props.validate(); + if (self.properties.lc + self.properties.lp == new_props.lc + new_props.lp) { + self.literal_probs.fill(0x400); + } else { + self.literal_probs.deinit(allocator); + self.literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (new_props.lc + new_props.lp), 0x300 }); + } + + self.properties = new_props; + for (&self.pos_slot_decoder) |*t| t.reset(); + self.align_decoder.reset(); + self.pos_decoders = @splat(0x400); + self.is_match = @splat(0x400); + self.is_rep = @splat(0x400); + self.is_rep_g0 = @splat(0x400); + self.is_rep_g1 = @splat(0x400); + self.is_rep_g2 = @splat(0x400); + self.is_rep_0long = @splat(0x400); + self.state = 0; + self.rep = @splat(0); + self.len_decoder.reset(); + self.rep_len_decoder.reset(); + } + + fn processNextInner( + self: *Decode, + allocator: Allocator, + br: *std.io.BufferedReader, + bw: *std.io.BufferedWriter, + buffer: anytype, + decoder: *RangeDecoder, + bytes_read: *usize, + update: bool, + ) !ProcessingStatus { + const pos_state = buffer.len & ((@as(usize, 1) << self.properties.pb) - 1); + + if (!try decoder.decodeBit(br, &self.is_match[(self.state << 4) + pos_state], update, bytes_read)) { + const byte: u8 = try self.decodeLiteral(br, buffer, decoder, update, bytes_read); + + if (update) { + try buffer.appendLiteral(allocator, byte, bw); + + self.state = if (self.state < 4) + 0 + else if (self.state < 10) + self.state - 3 + else + self.state - 6; + } + return .cont; + } + + var len: usize = undefined; + if (try decoder.decodeBit(br, &self.is_rep[self.state], update, bytes_read)) { + if (!try decoder.decodeBit(br, &self.is_rep_g0[self.state], update, bytes_read)) { + if (!try decoder.decodeBit(br, &self.is_rep_0long[(self.state << 4) + pos_state], update, bytes_read)) { + if (update) { + self.state = if (self.state < 7) 9 else 11; + const dist = self.rep[0] + 1; + try buffer.appendLz(allocator, 1, dist, bw); + } + return .cont; + } + } else { + const idx: usize = if (!try decoder.decodeBit(br, &self.is_rep_g1[self.state], update, bytes_read)) + 1 + else if (!try decoder.decodeBit(br, &self.is_rep_g2[self.state], update, bytes_read)) + 2 + else + 3; + if (update) { + const dist = self.rep[idx]; + var i = idx; + while (i > 0) : (i -= 1) { + self.rep[i] = self.rep[i - 1]; + } + self.rep[0] = dist; + } + } + + len = try self.rep_len_decoder.decode(br, decoder, pos_state, update, bytes_read); + + if (update) { + self.state = if (self.state < 7) 8 else 11; + } + } else { + if (update) { + self.rep[3] = self.rep[2]; + self.rep[2] = self.rep[1]; + self.rep[1] = self.rep[0]; + } + + len = try self.len_decoder.decode(br, decoder, pos_state, update, bytes_read); + + if (update) { + self.state = if (self.state < 7) 7 else 10; + } + + const rep_0 = try self.decodeDistance(br, decoder, len, update, bytes_read); + + if (update) { + self.rep[0] = rep_0; + if (self.rep[0] == 0xFFFF_FFFF) { + if (decoder.isFinished()) { + return .finished; + } + return error.CorruptInput; + } + } + } + + if (update) { + len += 2; + + const dist = self.rep[0] + 1; + try buffer.appendLz(allocator, len, dist, bw); + } + + return .cont; + } + + fn processNext( + self: *Decode, + allocator: Allocator, + br: *std.io.BufferedReader, + bw: *std.io.BufferedWriter, + buffer: anytype, + decoder: *RangeDecoder, + bytes_read: *usize, + ) !ProcessingStatus { + return self.processNextInner(allocator, br, bw, buffer, decoder, bytes_read, true); + } + + pub fn process( + self: *Decode, + allocator: Allocator, + br: *std.io.BufferedReader, + bw: *std.io.BufferedWriter, + buffer: anytype, + decoder: *RangeDecoder, + bytes_read: *usize, + ) !ProcessingStatus { + process_next: { + if (self.unpacked_size) |unpacked_size| { + if (buffer.len >= unpacked_size) { + break :process_next; + } + } else if (decoder.isFinished()) { + break :process_next; + } + + switch (try self.processNext(allocator, br, bw, buffer, decoder, bytes_read)) { + .cont => return .cont, + .finished => break :process_next, + } + } + + if (self.unpacked_size) |unpacked_size| { + if (buffer.len != unpacked_size) { + return error.CorruptInput; + } + } + + return .finished; + } + + fn decodeLiteral( + self: *Decode, + br: *std.io.BufferedReader, + buffer: anytype, + decoder: *RangeDecoder, + update: bool, + bytes_read: *usize, + ) !u8 { + const def_prev_byte = 0; + const prev_byte = @as(usize, buffer.lastOr(def_prev_byte)); + + var result: usize = 1; + const lit_state = ((buffer.len & ((@as(usize, 1) << self.properties.lp) - 1)) << self.properties.lc) + + (prev_byte >> (8 - self.properties.lc)); + const probs = try self.literal_probs.getMut(lit_state); + + if (self.state >= 7) { + var match_byte = @as(usize, try buffer.lastN(self.rep[0] + 1)); + + while (result < 0x100) { + const match_bit = (match_byte >> 7) & 1; + match_byte <<= 1; + const bit = @intFromBool(try decoder.decodeBit( + br, + &probs[((@as(usize, 1) + match_bit) << 8) + result], + update, + bytes_read, + )); + result = (result << 1) ^ bit; + if (match_bit != bit) { + break; + } + } + } + + while (result < 0x100) { + result = (result << 1) ^ @intFromBool(try decoder.decodeBit(br, &probs[result], update, bytes_read)); + } + + return @as(u8, @truncate(result - 0x100)); + } + + fn decodeDistance( + self: *Decode, + br: *std.io.BufferedReader, + decoder: *RangeDecoder, + length: usize, + update: bool, + bytes_read: *usize, + ) !usize { + const len_state = if (length > 3) 3 else length; + + const pos_slot = @as(usize, try self.pos_slot_decoder[len_state].parse(br, decoder, update, bytes_read)); + if (pos_slot < 4) + return pos_slot; + + const num_direct_bits = @as(u5, @intCast((pos_slot >> 1) - 1)); + var result = (2 ^ (pos_slot & 1)) << num_direct_bits; + + if (pos_slot < 14) { + result += try decoder.parseReverseBitTree( + br, + num_direct_bits, + &self.pos_decoders, + result - pos_slot, + update, + bytes_read, + ); + } else { + result += @as(usize, try decoder.get(br, num_direct_bits - 4, bytes_read)) << 4; + result += try self.align_decoder.parseReverse(br, decoder, update, bytes_read); + } + + return result; + } +}; + +pub const Decompress = struct { + pub const Error = + anyerror || + Allocator.Error || + error{ CorruptInput, EndOfStream, Overflow }; + + allocator: Allocator, + in_reader: *std.io.BufferedReader, + to_read: std.ArrayListUnmanaged(u8), + + buffer: LzCircularBuffer, + decoder: RangeDecoder, + state: Decode, + + pub fn initOptions(allocator: Allocator, br: *std.io.BufferedReader, options: Decode.Options) !Decompress { + const params = try Decode.Params.readHeader(br, options); + return init(allocator, br, params, options.memlimit); + } + + pub fn init(allocator: Allocator, source: *std.io.BufferedReader, params: Decode.Params, memlimit: ?usize) !Decompress { + return .{ + .allocator = allocator, + .in_reader = source, + .to_read = .{}, + + .buffer = LzCircularBuffer.init(params.dict_size, memlimit orelse math.maxInt(usize)), + .decoder = try RangeDecoder.init(source), + .state = try Decode.init(allocator, params.properties, params.unpacked_size), + }; + } + + pub fn reader(self: *Decompress) std.io.Reader { + return .{ .context = self }; + } + + pub fn deinit(self: *Decompress) void { + self.to_read.deinit(self.allocator); + self.buffer.deinit(self.allocator); + self.state.deinit(self.allocator); + self.* = undefined; + } + + pub fn read(self: *Decompress, output: []u8) Error!usize { + const bw = self.to_read.writer(self.allocator); + while (self.to_read.items.len < output.len) { + switch (try self.state.process(self.allocator, self.in_reader, bw, &self.buffer, &self.decoder)) { + .cont => {}, + .finished => { + try self.buffer.finish(bw); + break; + }, + } + } + const input = self.to_read.items; + const n = @min(input.len, output.len); + @memcpy(output[0..n], input[0..n]); + std.mem.copyForwards(u8, input[0 .. input.len - n], input[n..]); + self.to_read.shrinkRetainingCapacity(input.len - n); + return n; + } +}; + +/// A circular buffer for LZ sequences +const LzCircularBuffer = struct { + /// Circular buffer + buf: std.ArrayListUnmanaged(u8), + + /// Length of the buffer + dict_size: usize, + + /// Buffer memory limit + memlimit: usize, + + /// Current position + cursor: usize, + + /// Total number of bytes sent through the buffer + len: usize, + + const Self = @This(); + + pub fn init(dict_size: usize, memlimit: usize) Self { + return Self{ + .buf = .{}, + .dict_size = dict_size, + .memlimit = memlimit, + .cursor = 0, + .len = 0, + }; + } + + pub fn get(self: Self, index: usize) u8 { + return if (0 <= index and index < self.buf.items.len) + self.buf.items[index] + else + 0; + } + + pub fn set(self: *Self, allocator: Allocator, index: usize, value: u8) !void { + if (index >= self.memlimit) { + return error.CorruptInput; + } + try self.buf.ensureTotalCapacity(allocator, index + 1); + while (self.buf.items.len < index) { + self.buf.appendAssumeCapacity(0); + } + self.buf.appendAssumeCapacity(value); + } + + /// Retrieve the last byte or return a default + pub fn lastOr(self: Self, lit: u8) u8 { + return if (self.len == 0) + lit + else + self.get((self.dict_size + self.cursor - 1) % self.dict_size); + } + + /// Retrieve the n-th last byte + pub fn lastN(self: Self, dist: usize) !u8 { + if (dist > self.dict_size or dist > self.len) { + return error.CorruptInput; + } + + const offset = (self.dict_size + self.cursor - dist) % self.dict_size; + return self.get(offset); + } + + /// Append a literal + pub fn appendLiteral( + self: *Self, + allocator: Allocator, + lit: u8, + bw: *std.io.BufferedWriter, + ) anyerror!void { + try self.set(allocator, self.cursor, lit); + self.cursor += 1; + self.len += 1; + + // Flush the circular buffer to the output + if (self.cursor == self.dict_size) { + try bw.writeAll(self.buf.items); + self.cursor = 0; + } + } + + /// Fetch an LZ sequence (length, distance) from inside the buffer + pub fn appendLz( + self: *Self, + allocator: Allocator, + len: usize, + dist: usize, + bw: *std.io.BufferedWriter, + ) anyerror!void { + if (dist > self.dict_size or dist > self.len) { + return error.CorruptInput; + } + + var offset = (self.dict_size + self.cursor - dist) % self.dict_size; + var i: usize = 0; + while (i < len) : (i += 1) { + const x = self.get(offset); + try self.appendLiteral(allocator, x, bw); + offset += 1; + if (offset == self.dict_size) { + offset = 0; + } + } + } + + pub fn finish(self: *Self, bw: *std.io.BufferedWriter) anyerror!void { + if (self.cursor > 0) { + try bw.writeAll(self.buf.items[0..self.cursor]); + self.cursor = 0; + } + } + + pub fn deinit(self: *Self, allocator: Allocator) void { + self.buf.deinit(allocator); + self.* = undefined; + } +}; + +pub fn Vec2D(comptime T: type) type { + return struct { + data: []T, + cols: usize, + + const Self = @This(); + + pub fn init(allocator: Allocator, value: T, size: struct { usize, usize }) !Self { + const len = try math.mul(usize, size[0], size[1]); + const data = try allocator.alloc(T, len); + @memset(data, value); + return Self{ + .data = data, + .cols = size[1], + }; + } + + pub fn deinit(self: *Self, allocator: Allocator) void { + allocator.free(self.data); + self.* = undefined; + } + + pub fn fill(self: *Self, value: T) void { + @memset(self.data, value); + } + + inline fn _get(self: Self, row: usize) ![]T { + const start_row = try math.mul(usize, row, self.cols); + const end_row = try math.add(usize, start_row, self.cols); + return self.data[start_row..end_row]; + } + + pub fn get(self: Self, row: usize) ![]const T { + return self._get(row); + } + + pub fn getMut(self: *Self, row: usize) ![]T { + return self._get(row); + } + }; +} + +test "Vec2D init" { + const allocator = testing.allocator; + var vec2d = try Vec2D(i32).init(allocator, 1, .{ 2, 3 }); + defer vec2d.deinit(allocator); + + try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(0)); + try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(1)); +} + +test "Vec2D init overflow" { + const allocator = testing.allocator; + try expectError( + error.Overflow, + Vec2D(i32).init(allocator, 1, .{ math.maxInt(usize), math.maxInt(usize) }), + ); +} + +test "Vec2D fill" { + const allocator = testing.allocator; + var vec2d = try Vec2D(i32).init(allocator, 0, .{ 2, 3 }); + defer vec2d.deinit(allocator); + + vec2d.fill(7); + + try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(0)); + try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(1)); +} + +test "Vec2D get" { + var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 }; + const vec2d = Vec2D(i32){ + .data = &data, + .cols = 2, + }; + + try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0)); + try expectEqualSlices(i32, &.{ 2, 3 }, try vec2d.get(1)); + try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2)); + try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3)); +} + +test "Vec2D getMut" { + var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 }; + var vec2d = Vec2D(i32){ + .data = &data, + .cols = 2, + }; + + const row = try vec2d.getMut(1); + row[1] = 9; + + try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0)); + // (1, 1) should be 9. + try expectEqualSlices(i32, &.{ 2, 9 }, try vec2d.get(1)); + try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2)); + try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3)); +} + +test "Vec2D get multiplication overflow" { + const allocator = testing.allocator; + var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 4 }); + defer matrix.deinit(allocator); + + const row = (math.maxInt(usize) / 4) + 1; + try expectError(error.Overflow, matrix.get(row)); + try expectError(error.Overflow, matrix.getMut(row)); +} + +test "Vec2D get addition overflow" { + const allocator = testing.allocator; + var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 5 }); + defer matrix.deinit(allocator); + + const row = math.maxInt(usize) / 5; + try expectError(error.Overflow, matrix.get(row)); + try expectError(error.Overflow, matrix.getMut(row)); +} + +fn testDecompress(compressed: []const u8) ![]u8 { + const allocator = std.testing.allocator; + var br: std.io.BufferedReader = undefined; + br.initFixed(compressed); + var decompressor = try Decompress.initOptions(allocator, &br, .{}); + defer decompressor.deinit(); + const reader = decompressor.reader(); + return reader.readAllAlloc(allocator, std.math.maxInt(usize)); +} + +fn testDecompressEqual(expected: []const u8, compressed: []const u8) !void { + const allocator = std.testing.allocator; + const decomp = try testDecompress(compressed); + defer allocator.free(decomp); + try std.testing.expectEqualSlices(u8, expected, decomp); +} + +fn testDecompressError(expected: anyerror, compressed: []const u8) !void { + return std.testing.expectError(expected, testDecompress(compressed)); +} + +test "decompress empty world" { + try testDecompressEqual( + "", + &[_]u8{ + 0x5d, 0x00, 0x00, 0x80, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x83, 0xff, + 0xfb, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, + }, + ); +} + +test "decompress hello world" { + try testDecompressEqual( + "Hello world\n", + &[_]u8{ + 0x5d, 0x00, 0x00, 0x80, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x24, 0x19, + 0x49, 0x98, 0x6f, 0x10, 0x19, 0xc6, 0xd7, 0x31, 0xeb, 0x36, 0x50, 0xb2, 0x98, 0x48, 0xff, 0xfe, + 0xa5, 0xb0, 0x00, + }, + ); +} + +test "decompress huge dict" { + try testDecompressEqual( + "Hello world\n", + &[_]u8{ + 0x5d, 0x7f, 0x7f, 0x7f, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x24, 0x19, + 0x49, 0x98, 0x6f, 0x10, 0x19, 0xc6, 0xd7, 0x31, 0xeb, 0x36, 0x50, 0xb2, 0x98, 0x48, 0xff, 0xfe, + 0xa5, 0xb0, 0x00, + }, + ); +} + +test "unknown size with end of payload marker" { + try testDecompressEqual( + "Hello\nWorld!\n", + @embedFile("testdata/good-unknown_size-with_eopm.lzma"), + ); +} + +test "known size without end of payload marker" { + try testDecompressEqual( + "Hello\nWorld!\n", + @embedFile("testdata/good-known_size-without_eopm.lzma"), + ); +} + +test "known size with end of payload marker" { + try testDecompressEqual( + "Hello\nWorld!\n", + @embedFile("testdata/good-known_size-with_eopm.lzma"), + ); +} + +test "too big uncompressed size in header" { + try testDecompressError( + error.CorruptInput, + @embedFile("testdata/bad-too_big_size-with_eopm.lzma"), + ); +} + +test "too small uncompressed size in header" { + try testDecompressError( + error.CorruptInput, + @embedFile("testdata/bad-too_small_size-without_eopm-3.lzma"), + ); +} + +test "reading one byte" { + const compressed = @embedFile("testdata/good-known_size-with_eopm.lzma"); + var br: std.io.BufferedReader = undefined; + br.initFixed(compressed); + var decompressor = try Decompress.initOptions(std.testing.allocator, &br, .{}); + defer decompressor.deinit(); + var buffer = [1]u8{0}; + _ = try decompressor.read(buffer[0..]); } diff --git a/lib/std/compress/lzma/decode.zig b/lib/std/compress/lzma/decode.zig deleted file mode 100644 index 37a3281314..0000000000 --- a/lib/std/compress/lzma/decode.zig +++ /dev/null @@ -1,539 +0,0 @@ -const std = @import("../../std.zig"); -const assert = std.debug.assert; -const math = std.math; -const Allocator = std.mem.Allocator; - -pub const lzbuffer = @import("decode/lzbuffer.zig"); - -const LzCircularBuffer = lzbuffer.LzCircularBuffer; -const Vec2D = @import("vec2d.zig").Vec2D; - -pub const RangeDecoder = struct { - range: u32, - code: u32, - - pub fn init(br: *std.io.BufferedReader) !RangeDecoder { - const reserved = try br.takeByte(); - if (reserved != 0) { - return error.CorruptInput; - } - return .{ - .range = 0xFFFF_FFFF, - .code = try br.readInt(u32, .big), - }; - } - - pub inline fn isFinished(self: RangeDecoder) bool { - return self.code == 0; - } - - inline fn normalize(self: *RangeDecoder, br: *std.io.BufferedReader) !void { - if (self.range < 0x0100_0000) { - self.range <<= 8; - self.code = (self.code << 8) ^ @as(u32, try br.takeByte()); - } - } - - inline fn getBit(self: *RangeDecoder, br: *std.io.BufferedReader) !bool { - self.range >>= 1; - - const bit = self.code >= self.range; - if (bit) - self.code -= self.range; - - try self.normalize(br); - return bit; - } - - pub fn get(self: *RangeDecoder, br: *std.io.BufferedReader, count: usize) !u32 { - var result: u32 = 0; - var i: usize = 0; - while (i < count) : (i += 1) - result = (result << 1) ^ @intFromBool(try self.getBit(br)); - return result; - } - - pub inline fn decodeBit(self: *RangeDecoder, br: *std.io.BufferedReader, prob: *u16, update: bool) !bool { - const bound = (self.range >> 11) * prob.*; - - if (self.code < bound) { - if (update) - prob.* += (0x800 - prob.*) >> 5; - self.range = bound; - - try self.normalize(br); - return false; - } else { - if (update) - prob.* -= prob.* >> 5; - self.code -= bound; - self.range -= bound; - - try self.normalize(br); - return true; - } - } - - fn parseBitTree( - self: *RangeDecoder, - br: *std.io.BufferedReader, - num_bits: u5, - probs: []u16, - update: bool, - ) !u32 { - var tmp: u32 = 1; - var i: @TypeOf(num_bits) = 0; - while (i < num_bits) : (i += 1) { - const bit = try self.decodeBit(br, &probs[tmp], update); - tmp = (tmp << 1) ^ @intFromBool(bit); - } - return tmp - (@as(u32, 1) << num_bits); - } - - pub fn parseReverseBitTree( - self: *RangeDecoder, - br: *std.io.BufferedReader, - num_bits: u5, - probs: []u16, - offset: usize, - update: bool, - ) !u32 { - var result: u32 = 0; - var tmp: usize = 1; - var i: @TypeOf(num_bits) = 0; - while (i < num_bits) : (i += 1) { - const bit = @intFromBool(try self.decodeBit(br, &probs[offset + tmp], update)); - tmp = (tmp << 1) ^ bit; - result ^= @as(u32, bit) << i; - } - return result; - } -}; - -pub fn BitTree(comptime num_bits: usize) type { - return struct { - probs: [1 << num_bits]u16 = @splat(0x400), - - const Self = @This(); - - pub fn parse( - self: *Self, - br: *std.io.BufferedReader, - decoder: *RangeDecoder, - update: bool, - ) !u32 { - return decoder.parseBitTree(br, num_bits, &self.probs, update); - } - - pub fn parseReverse( - self: *Self, - br: *std.io.BufferedReader, - decoder: *RangeDecoder, - update: bool, - ) !u32 { - return decoder.parseReverseBitTree(br, num_bits, &self.probs, 0, update); - } - - pub fn reset(self: *Self) void { - @memset(&self.probs, 0x400); - } - }; -} - -pub const LenDecoder = struct { - choice: u16 = 0x400, - choice2: u16 = 0x400, - low_coder: [16]BitTree(3) = @splat(.{}), - mid_coder: [16]BitTree(3) = @splat(.{}), - high_coder: BitTree(8) = .{}, - - pub fn decode( - self: *LenDecoder, - br: *std.io.BufferedReader, - decoder: *RangeDecoder, - pos_state: usize, - update: bool, - ) !usize { - if (!try decoder.decodeBit(br, &self.choice, update)) { - return @as(usize, try self.low_coder[pos_state].parse(br, decoder, update)); - } else if (!try decoder.decodeBit(br, &self.choice2, update)) { - return @as(usize, try self.mid_coder[pos_state].parse(br, decoder, update)) + 8; - } else { - return @as(usize, try self.high_coder.parse(br, decoder, update)) + 16; - } - } - - pub fn reset(self: *LenDecoder) void { - self.choice = 0x400; - self.choice2 = 0x400; - for (&self.low_coder) |*t| t.reset(); - for (&self.mid_coder) |*t| t.reset(); - self.high_coder.reset(); - } -}; - -pub const Options = struct { - unpacked_size: UnpackedSize = .read_from_header, - memlimit: ?usize = null, - allow_incomplete: bool = false, -}; - -pub const UnpackedSize = union(enum) { - read_from_header, - read_header_but_use_provided: ?u64, - use_provided: ?u64, -}; - -const ProcessingStatus = enum { - continue_, - finished, -}; - -pub const Properties = struct { - lc: u4, - lp: u3, - pb: u3, - - fn validate(self: Properties) void { - assert(self.lc <= 8); - assert(self.lp <= 4); - assert(self.pb <= 4); - } -}; - -pub const Params = struct { - properties: Properties, - dict_size: u32, - unpacked_size: ?u64, - - pub fn readHeader(reader: anytype, options: Options) !Params { - var props = try reader.readByte(); - if (props >= 225) { - return error.CorruptInput; - } - - const lc = @as(u4, @intCast(props % 9)); - props /= 9; - const lp = @as(u3, @intCast(props % 5)); - props /= 5; - const pb = @as(u3, @intCast(props)); - - const dict_size_provided = try reader.readInt(u32, .little); - const dict_size = @max(0x1000, dict_size_provided); - - const unpacked_size = switch (options.unpacked_size) { - .read_from_header => blk: { - const unpacked_size_provided = try reader.readInt(u64, .little); - const marker_mandatory = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF; - break :blk if (marker_mandatory) - null - else - unpacked_size_provided; - }, - .read_header_but_use_provided => |x| blk: { - _ = try reader.readInt(u64, .little); - break :blk x; - }, - .use_provided => |x| x, - }; - - return Params{ - .properties = Properties{ .lc = lc, .lp = lp, .pb = pb }, - .dict_size = dict_size, - .unpacked_size = unpacked_size, - }; - } -}; - -pub const DecoderState = struct { - lzma_props: Properties, - unpacked_size: ?u64, - literal_probs: Vec2D(u16), - pos_slot_decoder: [4]BitTree(6), - align_decoder: BitTree(4), - pos_decoders: [115]u16, - is_match: [192]u16, - is_rep: [12]u16, - is_rep_g0: [12]u16, - is_rep_g1: [12]u16, - is_rep_g2: [12]u16, - is_rep_0long: [192]u16, - state: usize, - rep: [4]usize, - len_decoder: LenDecoder, - rep_len_decoder: LenDecoder, - - pub fn init( - allocator: Allocator, - lzma_props: Properties, - unpacked_size: ?u64, - ) !DecoderState { - return .{ - .lzma_props = lzma_props, - .unpacked_size = unpacked_size, - .literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (lzma_props.lc + lzma_props.lp), 0x300 }), - .pos_slot_decoder = @splat(.{}), - .align_decoder = .{}, - .pos_decoders = @splat(0x400), - .is_match = @splat(0x400), - .is_rep = @splat(0x400), - .is_rep_g0 = @splat(0x400), - .is_rep_g1 = @splat(0x400), - .is_rep_g2 = @splat(0x400), - .is_rep_0long = @splat(0x400), - .state = 0, - .rep = @splat(0), - .len_decoder = .{}, - .rep_len_decoder = .{}, - }; - } - - pub fn deinit(self: *DecoderState, allocator: Allocator) void { - self.literal_probs.deinit(allocator); - self.* = undefined; - } - - pub fn resetState(self: *DecoderState, allocator: Allocator, new_props: Properties) !void { - new_props.validate(); - if (self.lzma_props.lc + self.lzma_props.lp == new_props.lc + new_props.lp) { - self.literal_probs.fill(0x400); - } else { - self.literal_probs.deinit(allocator); - self.literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (new_props.lc + new_props.lp), 0x300 }); - } - - self.lzma_props = new_props; - for (&self.pos_slot_decoder) |*t| t.reset(); - self.align_decoder.reset(); - self.pos_decoders = @splat(0x400); - self.is_match = @splat(0x400); - self.is_rep = @splat(0x400); - self.is_rep_g0 = @splat(0x400); - self.is_rep_g1 = @splat(0x400); - self.is_rep_g2 = @splat(0x400); - self.is_rep_0long = @splat(0x400); - self.state = 0; - self.rep = @splat(0); - self.len_decoder.reset(); - self.rep_len_decoder.reset(); - } - - fn processNextInner( - self: *DecoderState, - allocator: Allocator, - reader: anytype, - writer: anytype, - buffer: anytype, - decoder: *RangeDecoder, - update: bool, - ) !ProcessingStatus { - const pos_state = buffer.len & ((@as(usize, 1) << self.lzma_props.pb) - 1); - - if (!try decoder.decodeBit( - reader, - &self.is_match[(self.state << 4) + pos_state], - update, - )) { - const byte: u8 = try self.decodeLiteral(reader, buffer, decoder, update); - - if (update) { - try buffer.appendLiteral(allocator, byte, writer); - - self.state = if (self.state < 4) - 0 - else if (self.state < 10) - self.state - 3 - else - self.state - 6; - } - return .continue_; - } - - var len: usize = undefined; - if (try decoder.decodeBit(reader, &self.is_rep[self.state], update)) { - if (!try decoder.decodeBit(reader, &self.is_rep_g0[self.state], update)) { - if (!try decoder.decodeBit( - reader, - &self.is_rep_0long[(self.state << 4) + pos_state], - update, - )) { - if (update) { - self.state = if (self.state < 7) 9 else 11; - const dist = self.rep[0] + 1; - try buffer.appendLz(allocator, 1, dist, writer); - } - return .continue_; - } - } else { - const idx: usize = if (!try decoder.decodeBit(reader, &self.is_rep_g1[self.state], update)) - 1 - else if (!try decoder.decodeBit(reader, &self.is_rep_g2[self.state], update)) - 2 - else - 3; - if (update) { - const dist = self.rep[idx]; - var i = idx; - while (i > 0) : (i -= 1) { - self.rep[i] = self.rep[i - 1]; - } - self.rep[0] = dist; - } - } - - len = try self.rep_len_decoder.decode(reader, decoder, pos_state, update); - - if (update) { - self.state = if (self.state < 7) 8 else 11; - } - } else { - if (update) { - self.rep[3] = self.rep[2]; - self.rep[2] = self.rep[1]; - self.rep[1] = self.rep[0]; - } - - len = try self.len_decoder.decode(reader, decoder, pos_state, update); - - if (update) { - self.state = if (self.state < 7) 7 else 10; - } - - const rep_0 = try self.decodeDistance(reader, decoder, len, update); - - if (update) { - self.rep[0] = rep_0; - if (self.rep[0] == 0xFFFF_FFFF) { - if (decoder.isFinished()) { - return .finished; - } - return error.CorruptInput; - } - } - } - - if (update) { - len += 2; - - const dist = self.rep[0] + 1; - try buffer.appendLz(allocator, len, dist, writer); - } - - return .continue_; - } - - fn processNext( - self: *DecoderState, - allocator: Allocator, - reader: anytype, - writer: anytype, - buffer: anytype, - decoder: *RangeDecoder, - ) !ProcessingStatus { - return self.processNextInner(allocator, reader, writer, buffer, decoder, true); - } - - pub fn process( - self: *DecoderState, - allocator: Allocator, - reader: anytype, - writer: anytype, - buffer: anytype, - decoder: *RangeDecoder, - ) !ProcessingStatus { - process_next: { - if (self.unpacked_size) |unpacked_size| { - if (buffer.len >= unpacked_size) { - break :process_next; - } - } else if (decoder.isFinished()) { - break :process_next; - } - - switch (try self.processNext(allocator, reader, writer, buffer, decoder)) { - .continue_ => return .continue_, - .finished => break :process_next, - } - } - - if (self.unpacked_size) |unpacked_size| { - if (buffer.len != unpacked_size) { - return error.CorruptInput; - } - } - - return .finished; - } - - fn decodeLiteral( - self: *DecoderState, - reader: anytype, - buffer: anytype, - decoder: *RangeDecoder, - update: bool, - ) !u8 { - const def_prev_byte = 0; - const prev_byte = @as(usize, buffer.lastOr(def_prev_byte)); - - var result: usize = 1; - const lit_state = ((buffer.len & ((@as(usize, 1) << self.lzma_props.lp) - 1)) << self.lzma_props.lc) + - (prev_byte >> (8 - self.lzma_props.lc)); - const probs = try self.literal_probs.getMut(lit_state); - - if (self.state >= 7) { - var match_byte = @as(usize, try buffer.lastN(self.rep[0] + 1)); - - while (result < 0x100) { - const match_bit = (match_byte >> 7) & 1; - match_byte <<= 1; - const bit = @intFromBool(try decoder.decodeBit( - reader, - &probs[((@as(usize, 1) + match_bit) << 8) + result], - update, - )); - result = (result << 1) ^ bit; - if (match_bit != bit) { - break; - } - } - } - - while (result < 0x100) { - result = (result << 1) ^ @intFromBool(try decoder.decodeBit(reader, &probs[result], update)); - } - - return @as(u8, @truncate(result - 0x100)); - } - - fn decodeDistance( - self: *DecoderState, - reader: anytype, - decoder: *RangeDecoder, - length: usize, - update: bool, - ) !usize { - const len_state = if (length > 3) 3 else length; - - const pos_slot = @as(usize, try self.pos_slot_decoder[len_state].parse(reader, decoder, update)); - if (pos_slot < 4) - return pos_slot; - - const num_direct_bits = @as(u5, @intCast((pos_slot >> 1) - 1)); - var result = (2 ^ (pos_slot & 1)) << num_direct_bits; - - if (pos_slot < 14) { - result += try decoder.parseReverseBitTree( - reader, - num_direct_bits, - &self.pos_decoders, - result - pos_slot, - update, - ); - } else { - result += @as(usize, try decoder.get(reader, num_direct_bits - 4)) << 4; - result += try self.align_decoder.parseReverse(reader, decoder, update); - } - - return result; - } -}; diff --git a/lib/std/compress/lzma/decode/lzbuffer.zig b/lib/std/compress/lzma/decode/lzbuffer.zig deleted file mode 100644 index 80c470c5f9..0000000000 --- a/lib/std/compress/lzma/decode/lzbuffer.zig +++ /dev/null @@ -1,228 +0,0 @@ -const std = @import("../../../std.zig"); -const math = std.math; -const mem = std.mem; -const Allocator = std.mem.Allocator; -const ArrayListUnmanaged = std.ArrayListUnmanaged; - -/// An accumulating buffer for LZ sequences -pub const LzAccumBuffer = struct { - /// Buffer - buf: ArrayListUnmanaged(u8), - - /// Buffer memory limit - memlimit: usize, - - /// Total number of bytes sent through the buffer - len: usize, - - const Self = @This(); - - pub fn init(memlimit: usize) Self { - return Self{ - .buf = .{}, - .memlimit = memlimit, - .len = 0, - }; - } - - pub fn appendByte(self: *Self, allocator: Allocator, byte: u8) !void { - try self.buf.append(allocator, byte); - self.len += 1; - } - - /// Reset the internal dictionary - pub fn reset(self: *Self, writer: anytype) !void { - try writer.writeAll(self.buf.items); - self.buf.clearRetainingCapacity(); - self.len = 0; - } - - /// Retrieve the last byte or return a default - pub fn lastOr(self: Self, lit: u8) u8 { - const buf_len = self.buf.items.len; - return if (buf_len == 0) - lit - else - self.buf.items[buf_len - 1]; - } - - /// Retrieve the n-th last byte - pub fn lastN(self: Self, dist: usize) !u8 { - const buf_len = self.buf.items.len; - if (dist > buf_len) { - return error.CorruptInput; - } - - return self.buf.items[buf_len - dist]; - } - - /// Append a literal - pub fn appendLiteral( - self: *Self, - allocator: Allocator, - lit: u8, - writer: anytype, - ) !void { - _ = writer; - if (self.len >= self.memlimit) { - return error.CorruptInput; - } - try self.buf.append(allocator, lit); - self.len += 1; - } - - /// Fetch an LZ sequence (length, distance) from inside the buffer - pub fn appendLz( - self: *Self, - allocator: Allocator, - len: usize, - dist: usize, - writer: anytype, - ) !void { - _ = writer; - - const buf_len = self.buf.items.len; - if (dist > buf_len) { - return error.CorruptInput; - } - - var offset = buf_len - dist; - var i: usize = 0; - while (i < len) : (i += 1) { - const x = self.buf.items[offset]; - try self.buf.append(allocator, x); - offset += 1; - } - self.len += len; - } - - pub fn finish(self: *Self, writer: anytype) !void { - try writer.writeAll(self.buf.items); - self.buf.clearRetainingCapacity(); - } - - pub fn deinit(self: *Self, allocator: Allocator) void { - self.buf.deinit(allocator); - self.* = undefined; - } -}; - -/// A circular buffer for LZ sequences -pub const LzCircularBuffer = struct { - /// Circular buffer - buf: ArrayListUnmanaged(u8), - - /// Length of the buffer - dict_size: usize, - - /// Buffer memory limit - memlimit: usize, - - /// Current position - cursor: usize, - - /// Total number of bytes sent through the buffer - len: usize, - - const Self = @This(); - - pub fn init(dict_size: usize, memlimit: usize) Self { - return Self{ - .buf = .{}, - .dict_size = dict_size, - .memlimit = memlimit, - .cursor = 0, - .len = 0, - }; - } - - pub fn get(self: Self, index: usize) u8 { - return if (0 <= index and index < self.buf.items.len) - self.buf.items[index] - else - 0; - } - - pub fn set(self: *Self, allocator: Allocator, index: usize, value: u8) !void { - if (index >= self.memlimit) { - return error.CorruptInput; - } - try self.buf.ensureTotalCapacity(allocator, index + 1); - while (self.buf.items.len < index) { - self.buf.appendAssumeCapacity(0); - } - self.buf.appendAssumeCapacity(value); - } - - /// Retrieve the last byte or return a default - pub fn lastOr(self: Self, lit: u8) u8 { - return if (self.len == 0) - lit - else - self.get((self.dict_size + self.cursor - 1) % self.dict_size); - } - - /// Retrieve the n-th last byte - pub fn lastN(self: Self, dist: usize) !u8 { - if (dist > self.dict_size or dist > self.len) { - return error.CorruptInput; - } - - const offset = (self.dict_size + self.cursor - dist) % self.dict_size; - return self.get(offset); - } - - /// Append a literal - pub fn appendLiteral( - self: *Self, - allocator: Allocator, - lit: u8, - writer: anytype, - ) !void { - try self.set(allocator, self.cursor, lit); - self.cursor += 1; - self.len += 1; - - // Flush the circular buffer to the output - if (self.cursor == self.dict_size) { - try writer.writeAll(self.buf.items); - self.cursor = 0; - } - } - - /// Fetch an LZ sequence (length, distance) from inside the buffer - pub fn appendLz( - self: *Self, - allocator: Allocator, - len: usize, - dist: usize, - writer: anytype, - ) !void { - if (dist > self.dict_size or dist > self.len) { - return error.CorruptInput; - } - - var offset = (self.dict_size + self.cursor - dist) % self.dict_size; - var i: usize = 0; - while (i < len) : (i += 1) { - const x = self.get(offset); - try self.appendLiteral(allocator, x, writer); - offset += 1; - if (offset == self.dict_size) { - offset = 0; - } - } - } - - pub fn finish(self: *Self, writer: anytype) !void { - if (self.cursor > 0) { - try writer.writeAll(self.buf.items[0..self.cursor]); - self.cursor = 0; - } - } - - pub fn deinit(self: *Self, allocator: Allocator) void { - self.buf.deinit(allocator); - self.* = undefined; - } -}; diff --git a/lib/std/compress/lzma/test.zig b/lib/std/compress/lzma/test.zig deleted file mode 100644 index eafb91b6bb..0000000000 --- a/lib/std/compress/lzma/test.zig +++ /dev/null @@ -1,99 +0,0 @@ -const std = @import("../../std.zig"); -const lzma = @import("../lzma.zig"); - -fn testDecompress(compressed: []const u8) ![]u8 { - const allocator = std.testing.allocator; - var stream = std.io.fixedBufferStream(compressed); - var decompressor = try lzma.decompress(allocator, stream.reader()); - defer decompressor.deinit(); - const reader = decompressor.reader(); - return reader.readAllAlloc(allocator, std.math.maxInt(usize)); -} - -fn testDecompressEqual(expected: []const u8, compressed: []const u8) !void { - const allocator = std.testing.allocator; - const decomp = try testDecompress(compressed); - defer allocator.free(decomp); - try std.testing.expectEqualSlices(u8, expected, decomp); -} - -fn testDecompressError(expected: anyerror, compressed: []const u8) !void { - return std.testing.expectError(expected, testDecompress(compressed)); -} - -test "decompress empty world" { - try testDecompressEqual( - "", - &[_]u8{ - 0x5d, 0x00, 0x00, 0x80, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x83, 0xff, - 0xfb, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, - }, - ); -} - -test "decompress hello world" { - try testDecompressEqual( - "Hello world\n", - &[_]u8{ - 0x5d, 0x00, 0x00, 0x80, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x24, 0x19, - 0x49, 0x98, 0x6f, 0x10, 0x19, 0xc6, 0xd7, 0x31, 0xeb, 0x36, 0x50, 0xb2, 0x98, 0x48, 0xff, 0xfe, - 0xa5, 0xb0, 0x00, - }, - ); -} - -test "decompress huge dict" { - try testDecompressEqual( - "Hello world\n", - &[_]u8{ - 0x5d, 0x7f, 0x7f, 0x7f, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x24, 0x19, - 0x49, 0x98, 0x6f, 0x10, 0x19, 0xc6, 0xd7, 0x31, 0xeb, 0x36, 0x50, 0xb2, 0x98, 0x48, 0xff, 0xfe, - 0xa5, 0xb0, 0x00, - }, - ); -} - -test "unknown size with end of payload marker" { - try testDecompressEqual( - "Hello\nWorld!\n", - @embedFile("testdata/good-unknown_size-with_eopm.lzma"), - ); -} - -test "known size without end of payload marker" { - try testDecompressEqual( - "Hello\nWorld!\n", - @embedFile("testdata/good-known_size-without_eopm.lzma"), - ); -} - -test "known size with end of payload marker" { - try testDecompressEqual( - "Hello\nWorld!\n", - @embedFile("testdata/good-known_size-with_eopm.lzma"), - ); -} - -test "too big uncompressed size in header" { - try testDecompressError( - error.CorruptInput, - @embedFile("testdata/bad-too_big_size-with_eopm.lzma"), - ); -} - -test "too small uncompressed size in header" { - try testDecompressError( - error.CorruptInput, - @embedFile("testdata/bad-too_small_size-without_eopm-3.lzma"), - ); -} - -test "reading one byte" { - const compressed = @embedFile("testdata/good-known_size-with_eopm.lzma"); - var stream = std.io.fixedBufferStream(compressed); - var decompressor = try lzma.decompress(std.testing.allocator, stream.reader()); - defer decompressor.deinit(); - - var buffer = [1]u8{0}; - _ = try decompressor.read(buffer[0..]); -} diff --git a/lib/std/compress/lzma/vec2d.zig b/lib/std/compress/lzma/vec2d.zig deleted file mode 100644 index df61093b85..0000000000 --- a/lib/std/compress/lzma/vec2d.zig +++ /dev/null @@ -1,128 +0,0 @@ -const std = @import("../../std.zig"); -const math = std.math; -const mem = std.mem; -const Allocator = std.mem.Allocator; - -pub fn Vec2D(comptime T: type) type { - return struct { - data: []T, - cols: usize, - - const Self = @This(); - - pub fn init(allocator: Allocator, value: T, size: struct { usize, usize }) !Self { - const len = try math.mul(usize, size[0], size[1]); - const data = try allocator.alloc(T, len); - @memset(data, value); - return Self{ - .data = data, - .cols = size[1], - }; - } - - pub fn deinit(self: *Self, allocator: Allocator) void { - allocator.free(self.data); - self.* = undefined; - } - - pub fn fill(self: *Self, value: T) void { - @memset(self.data, value); - } - - inline fn _get(self: Self, row: usize) ![]T { - const start_row = try math.mul(usize, row, self.cols); - const end_row = try math.add(usize, start_row, self.cols); - return self.data[start_row..end_row]; - } - - pub fn get(self: Self, row: usize) ![]const T { - return self._get(row); - } - - pub fn getMut(self: *Self, row: usize) ![]T { - return self._get(row); - } - }; -} - -const testing = std.testing; -const expectEqualSlices = std.testing.expectEqualSlices; -const expectError = std.testing.expectError; - -test "init" { - const allocator = testing.allocator; - var vec2d = try Vec2D(i32).init(allocator, 1, .{ 2, 3 }); - defer vec2d.deinit(allocator); - - try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(0)); - try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(1)); -} - -test "init overflow" { - const allocator = testing.allocator; - try expectError( - error.Overflow, - Vec2D(i32).init(allocator, 1, .{ math.maxInt(usize), math.maxInt(usize) }), - ); -} - -test "fill" { - const allocator = testing.allocator; - var vec2d = try Vec2D(i32).init(allocator, 0, .{ 2, 3 }); - defer vec2d.deinit(allocator); - - vec2d.fill(7); - - try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(0)); - try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(1)); -} - -test "get" { - var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 }; - const vec2d = Vec2D(i32){ - .data = &data, - .cols = 2, - }; - - try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0)); - try expectEqualSlices(i32, &.{ 2, 3 }, try vec2d.get(1)); - try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2)); - try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3)); -} - -test "getMut" { - var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 }; - var vec2d = Vec2D(i32){ - .data = &data, - .cols = 2, - }; - - const row = try vec2d.getMut(1); - row[1] = 9; - - try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0)); - // (1, 1) should be 9. - try expectEqualSlices(i32, &.{ 2, 9 }, try vec2d.get(1)); - try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2)); - try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3)); -} - -test "get multiplication overflow" { - const allocator = testing.allocator; - var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 4 }); - defer matrix.deinit(allocator); - - const row = (math.maxInt(usize) / 4) + 1; - try expectError(error.Overflow, matrix.get(row)); - try expectError(error.Overflow, matrix.getMut(row)); -} - -test "get addition overflow" { - const allocator = testing.allocator; - var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 5 }); - defer matrix.deinit(allocator); - - const row = math.maxInt(usize) / 5; - try expectError(error.Overflow, matrix.get(row)); - try expectError(error.Overflow, matrix.getMut(row)); -} diff --git a/lib/std/compress/lzma2.zig b/lib/std/compress/lzma2.zig index 4306e79214..6472e65760 100644 --- a/lib/std/compress/lzma2.zig +++ b/lib/std/compress/lzma2.zig @@ -1,15 +1,276 @@ const std = @import("../std.zig"); const Allocator = std.mem.Allocator; +const lzma = std.compress.lzma; -pub const decode = @import("lzma2/decode.zig"); - -pub fn decompress(allocator: Allocator, reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) !void { - var decoder = try decode.Decoder.init(allocator); - defer decoder.deinit(allocator); - return decoder.decompress(allocator, reader, writer); +pub fn decompress(gpa: Allocator, reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) anyerror!void { + var decoder = try Decode.init(gpa); + defer decoder.deinit(gpa); + return decoder.decompress(gpa, reader, writer); } -test { +pub const Decode = struct { + lzma1: lzma.Decode, + + pub fn init(allocator: Allocator) !Decode { + return .{ + .lzma1 = try lzma.Decode.init( + allocator, + .{ + .lc = 0, + .lp = 0, + .pb = 0, + }, + null, + ), + }; + } + + pub fn deinit(self: *Decode, allocator: Allocator) void { + self.lzma1.deinit(allocator); + self.* = undefined; + } + + pub fn decompress( + self: *Decode, + allocator: Allocator, + reader: *std.io.BufferedReader, + writer: *std.io.BufferedWriter, + ) !void { + var accum = LzAccumBuffer.init(std.math.maxInt(usize)); + defer accum.deinit(allocator); + + while (true) { + const status = try reader.takeByte(); + + switch (status) { + 0 => break, + 1 => try parseUncompressed(allocator, reader, writer, &accum, true), + 2 => try parseUncompressed(allocator, reader, writer, &accum, false), + else => try self.parseLzma(allocator, reader, writer, &accum, status), + } + } + + try accum.finish(writer); + } + + fn parseLzma( + self: *Decode, + allocator: Allocator, + br: *std.io.BufferedReader, + writer: *std.io.BufferedWriter, + accum: *LzAccumBuffer, + status: u8, + ) !void { + if (status & 0x80 == 0) { + return error.CorruptInput; + } + + const Reset = struct { + dict: bool, + state: bool, + props: bool, + }; + + const reset = switch ((status >> 5) & 0x3) { + 0 => Reset{ + .dict = false, + .state = false, + .props = false, + }, + 1 => Reset{ + .dict = false, + .state = true, + .props = false, + }, + 2 => Reset{ + .dict = false, + .state = true, + .props = true, + }, + 3 => Reset{ + .dict = true, + .state = true, + .props = true, + }, + else => unreachable, + }; + + const unpacked_size = blk: { + var tmp: u64 = status & 0x1F; + tmp <<= 16; + tmp |= try br.takeInt(u16, .big); + break :blk tmp + 1; + }; + + const packed_size = blk: { + const tmp: u17 = try br.takeInt(u16, .big); + break :blk tmp + 1; + }; + + if (reset.dict) { + try accum.reset(writer); + } + + if (reset.state) { + var new_props = self.lzma1.properties; + + if (reset.props) { + var props = try br.takeByte(); + if (props >= 225) { + return error.CorruptInput; + } + + const lc = @as(u4, @intCast(props % 9)); + props /= 9; + const lp = @as(u3, @intCast(props % 5)); + props /= 5; + const pb = @as(u3, @intCast(props)); + + if (lc + lp > 4) { + return error.CorruptInput; + } + + new_props = .{ .lc = lc, .lp = lp, .pb = pb }; + } + + try self.lzma1.resetState(allocator, new_props); + } + + self.lzma1.unpacked_size = unpacked_size + accum.len; + + var range_decoder: lzma.RangeDecoder = undefined; + var bytes_read = try lzma.RangeDecoder.init(br); + while (try self.lzma1.process(allocator, br, writer, accum, &range_decoder, &bytes_read) == .cont) {} + + if (bytes_read != packed_size) { + return error.CorruptInput; + } + } + + fn parseUncompressed( + allocator: Allocator, + reader: *std.io.BufferedReader, + writer: *std.io.BufferedWriter, + accum: *LzAccumBuffer, + reset_dict: bool, + ) !void { + const unpacked_size = @as(u17, try reader.takeInt(u16, .big)) + 1; + + if (reset_dict) { + try accum.reset(writer); + } + + var i: @TypeOf(unpacked_size) = 0; + while (i < unpacked_size) : (i += 1) { + try accum.appendByte(allocator, try reader.takeByte()); + } + } +}; + +/// An accumulating buffer for LZ sequences +const LzAccumBuffer = struct { + /// Buffer + buf: std.ArrayListUnmanaged(u8), + + /// Buffer memory limit + memlimit: usize, + + /// Total number of bytes sent through the buffer + len: usize, + + const Self = @This(); + + pub fn init(memlimit: usize) Self { + return Self{ + .buf = .{}, + .memlimit = memlimit, + .len = 0, + }; + } + + pub fn appendByte(self: *Self, allocator: Allocator, byte: u8) !void { + try self.buf.append(allocator, byte); + self.len += 1; + } + + /// Reset the internal dictionary + pub fn reset(self: *Self, writer: anytype) !void { + try writer.writeAll(self.buf.items); + self.buf.clearRetainingCapacity(); + self.len = 0; + } + + /// Retrieve the last byte or return a default + pub fn lastOr(self: Self, lit: u8) u8 { + const buf_len = self.buf.items.len; + return if (buf_len == 0) + lit + else + self.buf.items[buf_len - 1]; + } + + /// Retrieve the n-th last byte + pub fn lastN(self: Self, dist: usize) !u8 { + const buf_len = self.buf.items.len; + if (dist > buf_len) { + return error.CorruptInput; + } + + return self.buf.items[buf_len - dist]; + } + + /// Append a literal + pub fn appendLiteral( + self: *Self, + allocator: Allocator, + lit: u8, + writer: anytype, + ) !void { + _ = writer; + if (self.len >= self.memlimit) { + return error.CorruptInput; + } + try self.buf.append(allocator, lit); + self.len += 1; + } + + /// Fetch an LZ sequence (length, distance) from inside the buffer + pub fn appendLz( + self: *Self, + allocator: Allocator, + len: usize, + dist: usize, + writer: anytype, + ) !void { + _ = writer; + + const buf_len = self.buf.items.len; + if (dist > buf_len) { + return error.CorruptInput; + } + + var offset = buf_len - dist; + var i: usize = 0; + while (i < len) : (i += 1) { + const x = self.buf.items[offset]; + try self.buf.append(allocator, x); + offset += 1; + } + self.len += len; + } + + pub fn finish(self: *Self, writer: anytype) !void { + try writer.writeAll(self.buf.items); + self.buf.clearRetainingCapacity(); + } + + pub fn deinit(self: *Self, allocator: Allocator) void { + self.buf.deinit(allocator); + self.* = undefined; + } +}; + +test decompress { const expected = "Hello\nWorld!\n"; const compressed = [_]u8{ 0x01, 0x00, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x0A, 0x02, diff --git a/lib/std/compress/lzma2/decode.zig b/lib/std/compress/lzma2/decode.zig deleted file mode 100644 index cf37ec36ff..0000000000 --- a/lib/std/compress/lzma2/decode.zig +++ /dev/null @@ -1,169 +0,0 @@ -const std = @import("../../std.zig"); -const Allocator = std.mem.Allocator; - -const lzma = @import("../lzma.zig"); -const DecoderState = lzma.decode.DecoderState; -const LzAccumBuffer = lzma.decode.lzbuffer.LzAccumBuffer; -const Properties = lzma.decode.Properties; -const RangeDecoder = lzma.decode.RangeDecoder; - -pub const Decoder = struct { - lzma_state: DecoderState, - - pub fn init(allocator: Allocator) !Decoder { - return Decoder{ - .lzma_state = try DecoderState.init( - allocator, - Properties{ - .lc = 0, - .lp = 0, - .pb = 0, - }, - null, - ), - }; - } - - pub fn deinit(self: *Decoder, allocator: Allocator) void { - self.lzma_state.deinit(allocator); - self.* = undefined; - } - - pub fn decompress( - self: *Decoder, - allocator: Allocator, - reader: *std.io.BufferedReader, - writer: *std.io.BufferedWriter, - ) !void { - var accum = LzAccumBuffer.init(std.math.maxInt(usize)); - defer accum.deinit(allocator); - - while (true) { - const status = try reader.takeByte(); - - switch (status) { - 0 => break, - 1 => try parseUncompressed(allocator, reader, writer, &accum, true), - 2 => try parseUncompressed(allocator, reader, writer, &accum, false), - else => try self.parseLzma(allocator, reader, writer, &accum, status), - } - } - - try accum.finish(writer); - } - - fn parseLzma( - self: *Decoder, - allocator: Allocator, - br: *std.io.BufferedReader, - writer: *std.io.BufferedWriter, - accum: *LzAccumBuffer, - status: u8, - ) !void { - if (status & 0x80 == 0) { - return error.CorruptInput; - } - - const Reset = struct { - dict: bool, - state: bool, - props: bool, - }; - - const reset = switch ((status >> 5) & 0x3) { - 0 => Reset{ - .dict = false, - .state = false, - .props = false, - }, - 1 => Reset{ - .dict = false, - .state = true, - .props = false, - }, - 2 => Reset{ - .dict = false, - .state = true, - .props = true, - }, - 3 => Reset{ - .dict = true, - .state = true, - .props = true, - }, - else => unreachable, - }; - - const unpacked_size = blk: { - var tmp: u64 = status & 0x1F; - tmp <<= 16; - tmp |= try br.takeInt(u16, .big); - break :blk tmp + 1; - }; - - const packed_size = blk: { - const tmp: u17 = try br.takeInt(u16, .big); - break :blk tmp + 1; - }; - - if (reset.dict) { - try accum.reset(writer); - } - - if (reset.state) { - var new_props = self.lzma_state.lzma_props; - - if (reset.props) { - var props = try br.takeByte(); - if (props >= 225) { - return error.CorruptInput; - } - - const lc = @as(u4, @intCast(props % 9)); - props /= 9; - const lp = @as(u3, @intCast(props % 5)); - props /= 5; - const pb = @as(u3, @intCast(props)); - - if (lc + lp > 4) { - return error.CorruptInput; - } - - new_props = Properties{ .lc = lc, .lp = lp, .pb = pb }; - } - - try self.lzma_state.resetState(allocator, new_props); - } - - self.lzma_state.unpacked_size = unpacked_size + accum.len; - - var counter: std.io.CountingReader = .{ .child_reader = br.reader() }; - var counter_reader = counter.reader().unbuffered(); - - var rangecoder = try RangeDecoder.init(&counter_reader); - while (try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder) == .continue_) {} - - if (counter.bytes_read != packed_size) { - return error.CorruptInput; - } - } - - fn parseUncompressed( - allocator: Allocator, - reader: *std.io.BufferedReader, - writer: *std.io.BufferedWriter, - accum: *LzAccumBuffer, - reset_dict: bool, - ) !void { - const unpacked_size = @as(u17, try reader.takeInt(u16, .big)) + 1; - - if (reset_dict) { - try accum.reset(writer); - } - - var i: @TypeOf(unpacked_size) = 0; - while (i < unpacked_size) : (i += 1) { - try accum.appendByte(allocator, try reader.takeByte()); - } - } -}; diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 7b41e1fe3e..58cec0d733 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -16,191 +16,187 @@ pub const DecompressorOptions = struct { pub const default_window_buffer_len = 8 * 1024 * 1024; }; -pub fn Decompressor(comptime ReaderType: type) type { - return struct { - const Self = @This(); +pub const Decompressor = struct { + const Self = @This(); - const table_size_max = types.compressed_block.table_size_max; + const table_size_max = types.compressed_block.table_size_max; - source: std.io.CountingReader(ReaderType), - state: enum { NewFrame, InFrame, LastBlock }, - decode_state: decompress.block.DecodeState, - frame_context: decompress.FrameContext, - buffer: WindowBuffer, - literal_fse_buffer: [table_size_max.literal]types.compressed_block.Table.Fse, - match_fse_buffer: [table_size_max.match]types.compressed_block.Table.Fse, - offset_fse_buffer: [table_size_max.offset]types.compressed_block.Table.Fse, - literals_buffer: [types.block_size_max]u8, - sequence_buffer: [types.block_size_max]u8, - verify_checksum: bool, - checksum: ?u32, - current_frame_decompressed_size: usize, + source: std.io.CountingReader, + state: enum { NewFrame, InFrame, LastBlock }, + decode_state: decompress.block.DecodeState, + frame_context: decompress.FrameContext, + buffer: WindowBuffer, + literal_fse_buffer: [table_size_max.literal]types.compressed_block.Table.Fse, + match_fse_buffer: [table_size_max.match]types.compressed_block.Table.Fse, + offset_fse_buffer: [table_size_max.offset]types.compressed_block.Table.Fse, + literals_buffer: [types.block_size_max]u8, + sequence_buffer: [types.block_size_max]u8, + verify_checksum: bool, + checksum: ?u32, + current_frame_decompressed_size: usize, - const WindowBuffer = struct { - data: []u8 = undefined, - read_index: usize = 0, - write_index: usize = 0, - }; - - pub const Error = ReaderType.Error || error{ - ChecksumFailure, - DictionaryIdFlagUnsupported, - MalformedBlock, - MalformedFrame, - OutOfMemory, - }; - - pub const Reader = std.io.Reader(*Self, Error, read); - - pub fn init(source: ReaderType, options: DecompressorOptions) Self { - return .{ - .source = std.io.countingReader(source), - .state = .NewFrame, - .decode_state = undefined, - .frame_context = undefined, - .buffer = .{ .data = options.window_buffer }, - .literal_fse_buffer = undefined, - .match_fse_buffer = undefined, - .offset_fse_buffer = undefined, - .literals_buffer = undefined, - .sequence_buffer = undefined, - .verify_checksum = options.verify_checksum, - .checksum = undefined, - .current_frame_decompressed_size = undefined, - }; - } - - fn frameInit(self: *Self) !void { - const source_reader = self.source.reader(); - switch (try decompress.decodeFrameHeader(source_reader)) { - .skippable => |header| { - try source_reader.skipBytes(header.frame_size, .{}); - self.state = .NewFrame; - }, - .zstandard => |header| { - const frame_context = try decompress.FrameContext.init( - header, - self.buffer.data.len, - self.verify_checksum, - ); - - const decode_state = decompress.block.DecodeState.init( - &self.literal_fse_buffer, - &self.match_fse_buffer, - &self.offset_fse_buffer, - ); - - self.decode_state = decode_state; - self.frame_context = frame_context; - - self.checksum = null; - self.current_frame_decompressed_size = 0; - - self.state = .InFrame; - }, - } - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - pub fn read(self: *Self, buffer: []u8) Error!usize { - if (buffer.len == 0) return 0; - - var size: usize = 0; - while (size == 0) { - while (self.state == .NewFrame) { - const initial_count = self.source.bytes_read; - self.frameInit() catch |err| switch (err) { - error.DictionaryIdFlagUnsupported => return error.DictionaryIdFlagUnsupported, - error.EndOfStream => return if (self.source.bytes_read == initial_count) - 0 - else - error.MalformedFrame, - else => return error.MalformedFrame, - }; - } - size = try self.readInner(buffer); - } - return size; - } - - fn readInner(self: *Self, buffer: []u8) Error!usize { - std.debug.assert(self.state != .NewFrame); - - var ring_buffer = RingBuffer{ - .data = self.buffer.data, - .read_index = self.buffer.read_index, - .write_index = self.buffer.write_index, - }; - defer { - self.buffer.read_index = ring_buffer.read_index; - self.buffer.write_index = ring_buffer.write_index; - } - - const source_reader = self.source.reader(); - while (ring_buffer.isEmpty() and self.state != .LastBlock) { - const header_bytes = source_reader.readBytesNoEof(3) catch - return error.MalformedFrame; - const block_header = decompress.block.decodeBlockHeader(&header_bytes); - - decompress.block.decodeBlockReader( - &ring_buffer, - source_reader, - block_header, - &self.decode_state, - self.frame_context.block_size_max, - &self.literals_buffer, - &self.sequence_buffer, - ) catch - return error.MalformedBlock; - - if (self.frame_context.content_size) |size| { - if (self.current_frame_decompressed_size > size) return error.MalformedFrame; - } - - const size = ring_buffer.len(); - self.current_frame_decompressed_size += size; - - if (self.frame_context.hasher_opt) |*hasher| { - if (size > 0) { - const written_slice = ring_buffer.sliceLast(size); - hasher.update(written_slice.first); - hasher.update(written_slice.second); - } - } - if (block_header.last_block) { - self.state = .LastBlock; - if (self.frame_context.has_checksum) { - const checksum = source_reader.readInt(u32, .little) catch - return error.MalformedFrame; - if (self.verify_checksum) { - if (self.frame_context.hasher_opt) |*hasher| { - if (checksum != decompress.computeChecksum(hasher)) - return error.ChecksumFailure; - } - } - } - if (self.frame_context.content_size) |content_size| { - if (content_size != self.current_frame_decompressed_size) { - return error.MalformedFrame; - } - } - } - } - - const size = @min(ring_buffer.len(), buffer.len); - if (size > 0) { - ring_buffer.readFirstAssumeLength(buffer, size); - } - if (self.state == .LastBlock and ring_buffer.len() == 0) { - self.state = .NewFrame; - } - return size; - } + const WindowBuffer = struct { + data: []u8 = undefined, + read_index: usize = 0, + write_index: usize = 0, }; -} + + pub const Error = anyerror || error{ + ChecksumFailure, + DictionaryIdFlagUnsupported, + MalformedBlock, + MalformedFrame, + OutOfMemory, + }; + + pub fn init(source: *std.io.BufferedReader, options: DecompressorOptions) Self { + return .{ + .source = std.io.countingReader(source), + .state = .NewFrame, + .decode_state = undefined, + .frame_context = undefined, + .buffer = .{ .data = options.window_buffer }, + .literal_fse_buffer = undefined, + .match_fse_buffer = undefined, + .offset_fse_buffer = undefined, + .literals_buffer = undefined, + .sequence_buffer = undefined, + .verify_checksum = options.verify_checksum, + .checksum = undefined, + .current_frame_decompressed_size = undefined, + }; + } + + fn frameInit(self: *Self) !void { + const source_reader = self.source; + switch (try decompress.decodeFrameHeader(source_reader)) { + .skippable => |header| { + try source_reader.skipBytes(header.frame_size, .{}); + self.state = .NewFrame; + }, + .zstandard => |header| { + const frame_context = try decompress.FrameContext.init( + header, + self.buffer.data.len, + self.verify_checksum, + ); + + const decode_state = decompress.block.DecodeState.init( + &self.literal_fse_buffer, + &self.match_fse_buffer, + &self.offset_fse_buffer, + ); + + self.decode_state = decode_state; + self.frame_context = frame_context; + + self.checksum = null; + self.current_frame_decompressed_size = 0; + + self.state = .InFrame; + }, + } + } + + pub fn reader(self: *Self) std.io.Reader { + return .{ .context = self }; + } + + pub fn read(self: *Self, buffer: []u8) Error!usize { + if (buffer.len == 0) return 0; + + var size: usize = 0; + while (size == 0) { + while (self.state == .NewFrame) { + const initial_count = self.source.bytes_read; + self.frameInit() catch |err| switch (err) { + error.DictionaryIdFlagUnsupported => return error.DictionaryIdFlagUnsupported, + error.EndOfStream => return if (self.source.bytes_read == initial_count) + 0 + else + error.MalformedFrame, + else => return error.MalformedFrame, + }; + } + size = try self.readInner(buffer); + } + return size; + } + + fn readInner(self: *Self, buffer: []u8) Error!usize { + std.debug.assert(self.state != .NewFrame); + + var ring_buffer = RingBuffer{ + .data = self.buffer.data, + .read_index = self.buffer.read_index, + .write_index = self.buffer.write_index, + }; + defer { + self.buffer.read_index = ring_buffer.read_index; + self.buffer.write_index = ring_buffer.write_index; + } + + const source_reader = self.source; + while (ring_buffer.isEmpty() and self.state != .LastBlock) { + const header_bytes = source_reader.readBytesNoEof(3) catch + return error.MalformedFrame; + const block_header = decompress.block.decodeBlockHeader(&header_bytes); + + decompress.block.decodeBlockReader( + &ring_buffer, + source_reader, + block_header, + &self.decode_state, + self.frame_context.block_size_max, + &self.literals_buffer, + &self.sequence_buffer, + ) catch + return error.MalformedBlock; + + if (self.frame_context.content_size) |size| { + if (self.current_frame_decompressed_size > size) return error.MalformedFrame; + } + + const size = ring_buffer.len(); + self.current_frame_decompressed_size += size; + + if (self.frame_context.hasher_opt) |*hasher| { + if (size > 0) { + const written_slice = ring_buffer.sliceLast(size); + hasher.update(written_slice.first); + hasher.update(written_slice.second); + } + } + if (block_header.last_block) { + self.state = .LastBlock; + if (self.frame_context.has_checksum) { + const checksum = source_reader.readInt(u32, .little) catch + return error.MalformedFrame; + if (self.verify_checksum) { + if (self.frame_context.hasher_opt) |*hasher| { + if (checksum != decompress.computeChecksum(hasher)) + return error.ChecksumFailure; + } + } + } + if (self.frame_context.content_size) |content_size| { + if (content_size != self.current_frame_decompressed_size) { + return error.MalformedFrame; + } + } + } + } + + const size = @min(ring_buffer.len(), buffer.len); + if (size > 0) { + ring_buffer.readFirstAssumeLength(buffer, size); + } + if (self.state == .LastBlock and ring_buffer.len() == 0) { + self.state = .NewFrame; + } + return size; + } +}; pub fn decompressor(reader: anytype, options: DecompressorOptions) Decompressor(@TypeOf(reader)) { return Decompressor(@TypeOf(reader)).init(reader, options); diff --git a/lib/std/debug/Dwarf.zig b/lib/std/debug/Dwarf.zig index b4a287df9f..6acf7b1dcc 100644 --- a/lib/std/debug/Dwarf.zig +++ b/lib/std/debug/Dwarf.zig @@ -2212,7 +2212,7 @@ pub const ElfModule = struct { var separate_debug_filename: ?[]const u8 = null; var separate_debug_crc: ?u32 = null; - for (shdrs) |*shdr| { + shdrs: for (shdrs) |*shdr| { if (shdr.sh_type == elf.SHT_NULL or shdr.sh_type == elf.SHT_NOBITS) continue; const name = mem.sliceTo(header_strings[shdr.sh_name..], 0); @@ -2246,8 +2246,22 @@ pub const ElfModule = struct { const decompressed_section = try gpa.alloc(u8, ch_size); errdefer gpa.free(decompressed_section); - const read = zlib_stream.reader().readAll(decompressed_section) catch continue; - assert(read == decompressed_section.len); + { + var read_index: usize = 0; + while (true) { + const read_result = zlib_stream.streamReadVec(&.{decompressed_section[read_index..]}); + read_result.err catch { + gpa.free(decompressed_section); + continue :shdrs; + }; + read_index += read_result.len; + if (read_index == decompressed_section.len) break; + if (read_result.end) { + gpa.free(decompressed_section); + continue :shdrs; + } + } + } break :blk .{ .data = decompressed_section, diff --git a/lib/std/debug/FixedBufferReader.zig b/lib/std/debug/FixedBufferReader.zig index e4aec1a9c6..ff9c817bcb 100644 --- a/lib/std/debug/FixedBufferReader.zig +++ b/lib/std/debug/FixedBufferReader.zig @@ -1,5 +1,7 @@ //! Optimized for performance in debug builds. +// TODO I'm pretty sure this can be deleted thanks to the new std.io.BufferedReader semantics + const std = @import("../std.zig"); const MemoryAccessor = std.debug.MemoryAccessor; @@ -9,20 +11,20 @@ buf: []const u8, pos: usize = 0, endian: std.builtin.Endian, -pub const Error = error{ EndOfBuffer, Overflow, InvalidBuffer }; +pub const Error = error{ EndOfStream, Overflow, InvalidBuffer }; pub fn seekTo(fbr: *FixedBufferReader, pos: u64) Error!void { - if (pos > fbr.buf.len) return error.EndOfBuffer; + if (pos > fbr.buf.len) return error.EndOfStream; fbr.pos = @intCast(pos); } pub fn seekForward(fbr: *FixedBufferReader, amount: u64) Error!void { - if (fbr.buf.len - fbr.pos < amount) return error.EndOfBuffer; + if (fbr.buf.len - fbr.pos < amount) return error.EndOfStream; fbr.pos += @intCast(amount); } pub inline fn readByte(fbr: *FixedBufferReader) Error!u8 { - if (fbr.pos >= fbr.buf.len) return error.EndOfBuffer; + if (fbr.pos >= fbr.buf.len) return error.EndOfStream; defer fbr.pos += 1; return fbr.buf[fbr.pos]; } @@ -33,7 +35,7 @@ pub fn readByteSigned(fbr: *FixedBufferReader) Error!i8 { pub fn readInt(fbr: *FixedBufferReader, comptime T: type) Error!T { const size = @divExact(@typeInfo(T).int.bits, 8); - if (fbr.buf.len - fbr.pos < size) return error.EndOfBuffer; + if (fbr.buf.len - fbr.pos < size) return error.EndOfStream; defer fbr.pos += size; return std.mem.readInt(T, fbr.buf[fbr.pos..][0..size], fbr.endian); } @@ -50,11 +52,21 @@ pub fn readIntChecked( } pub fn readUleb128(fbr: *FixedBufferReader, comptime T: type) Error!T { - return std.leb.readUleb128(T, fbr); + var br: std.io.BufferedReader = undefined; + br.initFixed(fbr.buf); + br.seek = fbr.pos; + const result = br.takeUleb128(T); + fbr.pos = br.seek; + return @errorCast(result); } pub fn readIleb128(fbr: *FixedBufferReader, comptime T: type) Error!T { - return std.leb.readIleb128(T, fbr); + var br: std.io.BufferedReader = undefined; + br.initFixed(fbr.buf); + br.seek = fbr.pos; + const result = br.takeIleb128(T); + fbr.pos = br.seek; + return @errorCast(result); } pub fn readAddress(fbr: *FixedBufferReader, format: std.dwarf.Format) Error!u64 { @@ -76,7 +88,7 @@ pub fn readAddressChecked( } pub fn readBytes(fbr: *FixedBufferReader, len: usize) Error![]const u8 { - if (fbr.buf.len - fbr.pos < len) return error.EndOfBuffer; + if (fbr.buf.len - fbr.pos < len) return error.EndOfStream; defer fbr.pos += len; return fbr.buf[fbr.pos..][0..len]; } @@ -87,7 +99,7 @@ pub fn readBytesTo(fbr: *FixedBufferReader, comptime sentinel: u8) Error![:senti fbr.buf, fbr.pos, sentinel, - }) orelse return error.EndOfBuffer; + }) orelse return error.EndOfStream; defer fbr.pos = end + 1; return fbr.buf[fbr.pos..end :sentinel]; } diff --git a/lib/std/debug/SelfInfo.zig b/lib/std/debug/SelfInfo.zig index 97b861e516..d1567c7560 100644 --- a/lib/std/debug/SelfInfo.zig +++ b/lib/std/debug/SelfInfo.zig @@ -2028,13 +2028,13 @@ pub const VirtualMachine = struct { var prev_row: Row = self.current_row; var cie_stream: std.io.BufferedReader = undefined; - cie_stream.initFixed(&cie.initial_instructions); + cie_stream.initFixed(cie.initial_instructions); var fde_stream: std.io.BufferedReader = undefined; - fde_stream.initFixed(&fde.instructions); - const streams: [2]*std.io.FixedBufferStream = .{ &cie_stream, &fde_stream }; + fde_stream.initFixed(fde.instructions); + const streams: [2]*std.io.BufferedReader = .{ &cie_stream, &fde_stream }; for (&streams, 0..) |stream, i| { - while (stream.pos < stream.buffer.len) { + while (stream.seek < stream.buffer.len) { const instruction = try std.debug.Dwarf.call_frame.Instruction.read(stream, addr_size_bytes, endian); prev_row = try self.step(allocator, cie, i == 0, instruction); if (pc < fde.pc_begin + self.current_row.offset) return prev_row; diff --git a/lib/std/fmt.zig b/lib/std/fmt.zig index b095be173d..523f78f7c2 100644 --- a/lib/std/fmt.zig +++ b/lib/std/fmt.zig @@ -91,7 +91,7 @@ pub const Options = struct { /// A user type may be a `struct`, `vector`, `union` or `enum` type. /// /// To print literal curly braces, escape them by writing them twice, e.g. `{{` or `}}`. -pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytype) anyerror!void { +pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytype) anyerror!usize { const ArgsType = @TypeOf(args); const args_type_info = @typeInfo(ArgsType); if (args_type_info != .@"struct") { @@ -107,6 +107,7 @@ pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytyp comptime var arg_state: ArgState = .{ .args_len = fields_info.len }; comptime var i = 0; comptime var literal: []const u8 = ""; + var bytes_written: usize = 0; inline while (true) { const start_index = i; @@ -136,7 +137,7 @@ pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytyp // Write out the literal if (literal.len != 0) { - try bw.writeAll(literal); + bytes_written += try bw.writeAllCount(literal); literal = ""; } @@ -196,7 +197,7 @@ pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytyp const arg_to_print = comptime arg_state.nextArg(arg_pos) orelse @compileError("too few arguments"); - try bw.printValue( + bytes_written += try bw.printValue( placeholder.specifier_arg, .{ .fill = placeholder.fill, @@ -217,6 +218,8 @@ pub fn format(bw: *std.io.BufferedWriter, comptime fmt: []const u8, args: anytyp else => @compileError(comptimePrint("{d}", .{missing_count}) ++ " unused arguments in '" ++ fmt ++ "'"), } } + + return bytes_written; } fn cacheString(str: anytype) []const u8 { @@ -852,11 +855,10 @@ pub fn bufPrintZ(buf: []u8, comptime fmt: []const u8, args: anytype) BufPrintErr } /// Count the characters needed for format. -pub fn count(comptime fmt: []const u8, args: anytype) u64 { - var counting_writer: std.io.CountingWriter = .{ .child_writer = std.io.null_writer }; - var bw = counting_writer.writer().unbuffered(); - bw.print(fmt, args) catch unreachable; - return counting_writer.bytes_written; +pub fn count(comptime fmt: []const u8, args: anytype) usize { + var buffer: [std.atomic.cache_line]u8 = undefined; + var bw = std.io.Writer.null.buffered(&buffer); + return bw.printCount(fmt, args) catch unreachable; } pub const AllocPrintError = error{OutOfMemory}; diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 95b36b3404..b1aab92855 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1512,23 +1512,48 @@ pub fn writeFileAllUnseekable(self: File, in_file: File, args: WriteFileOptions) return @errorCast(writeFileAllUnseekableInner(self, in_file, args)); } -fn writeFileAllUnseekableInner(self: File, in_file: File, args: WriteFileOptions) anyerror!void { +fn writeFileAllUnseekableInner(out_file: File, in_file: File, args: WriteFileOptions) anyerror!void { const headers = args.headers_and_trailers[0..args.header_count]; const trailers = args.headers_and_trailers[args.header_count..]; - try self.writevAll(headers); + try out_file.writevAll(headers); - try in_file.reader().skipBytes(args.in_offset, .{ .buf_size = 4096 }); + // Some possible optimizations here: + // * Could writev buffer multiple times if the amount to discard is larger than 4096 + // * Could combine discard and read in one readv if amount to discard is small - var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init(); + var buffer: [4096]u8 = undefined; + var remaining = args.in_offset; + while (remaining > 0) { + const n = try in_file.read(buffer[0..@min(buffer.len, remaining)]); + if (n == 0) return error.EndOfStream; + remaining -= n; + } if (args.in_len) |len| { - var stream = std.io.limitedReader(in_file.reader(), len); - try fifo.pump(stream.reader(), self.writer()); + remaining = len; + var buffer_index: usize = 0; + while (remaining > 0) { + const n = buffer_index + try in_file.read(buffer[buffer_index..@min(buffer.len, remaining)]); + if (n == 0) return error.EndOfStream; + const written = try out_file.write(buffer[0..n]); + if (written == 0) return error.EndOfStream; + remaining -= written; + std.mem.copyForwards(u8, &buffer, buffer[written..n]); + buffer_index = n - written; + } } else { - try fifo.pump(in_file.reader(), self.writer()); + var buffer_index: usize = 0; + while (true) { + const n = buffer_index + try in_file.read(buffer[buffer_index..]); + if (n == 0) break; + const written = try out_file.write(buffer[0..n]); + if (written == 0) return error.EndOfStream; + std.mem.copyForwards(u8, &buffer, buffer[written..n]); + buffer_index = n - written; + } } - try self.writevAll(trailers); + try out_file.writevAll(trailers); } /// Low level function which can fail for OS-specific reasons. @@ -1645,7 +1670,7 @@ pub fn reader_posReadVec(context: *anyopaque, data: []const []u8, offset: u64) a } pub fn reader_streamRead( - context: *anyopaque, + context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit, ) anyerror!std.io.Reader.Status { @@ -1658,7 +1683,7 @@ pub fn reader_streamRead( }; } -pub fn reader_streamReadVec(context: *anyopaque, data: []const []u8) anyerror!std.io.Reader.Status { +pub fn reader_streamReadVec(context: ?*anyopaque, data: []const []u8) anyerror!std.io.Reader.Status { const file = opaqueToHandle(context); const n = try file.readv(data); return .{ @@ -1667,12 +1692,12 @@ pub fn reader_streamReadVec(context: *anyopaque, data: []const []u8) anyerror!st }; } -pub fn writer_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { +pub fn writer_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Result { const file = opaqueToHandle(context); var splat_buffer: [256]u8 = undefined; if (is_windows) { if (data.len == 1 and splat == 0) return 0; - return windows.WriteFile(file, data[0], null); + return .{ .len = windows.WriteFile(file, data[0], null) catch |err| return .{ .err = err } }; } var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; var len: usize = @min(iovecs.len, data.len); @@ -1681,8 +1706,8 @@ pub fn writer_writeSplat(context: *anyopaque, data: []const []const u8, splat: u .len = d.len, }; switch (splat) { - 0 => return std.posix.writev(file, iovecs[0 .. len - 1]), - 1 => return std.posix.writev(file, iovecs[0..len]), + 0 => return .{ .len = std.posix.writev(file, iovecs[0 .. len - 1]) catch |err| return .{ .err = err } }, + 1 => return .{ .len = std.posix.writev(file, iovecs[0..len]) catch |err| return .{ .err = err } }, else => { const pattern = data[data.len - 1]; if (pattern.len == 1) { @@ -1700,21 +1725,21 @@ pub fn writer_writeSplat(context: *anyopaque, data: []const []const u8, splat: u iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; len += 1; } - return std.posix.writev(file, iovecs[0..len]); + return .{ .len = std.posix.writev(file, iovecs[0..len]) catch |err| return .{ .err = err } }; } }, } - return std.posix.writev(file, iovecs[0..len]); + return .{ .len = std.posix.writev(file, iovecs[0..len]) catch |err| return .{ .err = err } }; } pub fn writer_writeFile( - context: *anyopaque, + context: ?*anyopaque, in_file: std.fs.File, in_offset: u64, in_len: std.io.Writer.FileLen, headers_and_trailers: []const []const u8, headers_len: usize, -) anyerror!usize { +) std.io.Writer.Result { const out_fd = opaqueToHandle(context); const in_fd = in_file.handle; const len_int = switch (in_len) { diff --git a/lib/std/io.zig b/lib/std/io.zig index af10d6da9b..1f7817b01f 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -20,8 +20,6 @@ pub const Writer = @import("io/Writer.zig"); pub const BufferedReader = @import("io/BufferedReader.zig"); pub const BufferedWriter = @import("io/BufferedWriter.zig"); pub const AllocatingWriter = @import("io/AllocatingWriter.zig"); -pub const CountingWriter = @import("io/CountingWriter.zig"); -pub const CountingReader = @import("io/CountingReader.zig"); pub const CWriter = @import("io/c_writer.zig").CWriter; pub const cWriter = @import("io/c_writer.zig").cWriter; @@ -48,46 +46,6 @@ pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAt pub const tty = @import("io/tty.zig"); -/// A `Writer` that discards all data. -pub const null_writer: Writer = .{ - .context = undefined, - .vtable = &.{ - .writeSplat = null_writeSplat, - .writeFile = null_writeFile, - }, -}; - -fn null_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { - _ = context; - const headers = data[0 .. data.len - 1]; - const pattern = data[headers.len..]; - var written: usize = pattern.len * splat; - for (headers) |bytes| written += bytes.len; - return written; -} - -fn null_writeFile( - context: *anyopaque, - file: std.fs.File, - offset: u64, - len: Writer.FileLen, - headers_and_trailers: []const []const u8, - headers_len: usize, -) anyerror!usize { - _ = context; - _ = offset; - _ = headers_len; - _ = file; - if (len == .entire_file) return error.Unimplemented; - var n: usize = 0; - for (headers_and_trailers) |bytes| n += bytes.len; - return len.int() + n; -} - -test null_writer { - try null_writer.writeAll("yay"); -} - pub fn poll( allocator: Allocator, comptime StreamEnum: type, @@ -494,8 +452,6 @@ test { _ = BufferedReader; _ = Reader; _ = Writer; - _ = CountingWriter; - _ = CountingReader; _ = AllocatingWriter; _ = @import("io/bit_reader.zig"); _ = @import("io/bit_writer.zig"); diff --git a/lib/std/io/BufferedReader.zig b/lib/std/io/BufferedReader.zig index 7359dc8243..05fe0e8912 100644 --- a/lib/std/io/BufferedReader.zig +++ b/lib/std/io/BufferedReader.zig @@ -14,26 +14,37 @@ seek: usize, storage: BufferedWriter, unbuffered_reader: Reader, +pub fn init(br: *BufferedReader, r: Reader, buffer: []u8) void { + br.* = .{ + .seek = 0, + .storage = undefined, + .unbuffered_reader = r, + }; + br.storage.initFixed(buffer); +} + +/// Constructs `br` such that it will read from `buffer` and then end. pub fn initFixed(br: *BufferedReader, buffer: []const u8) void { br.* = .{ .seek = 0, .storage = .{ - .buffer = buffer, - .mode = .fixed, - }, - .reader = .{ - .context = br, - .vtable = &.{ - .streamRead = null, - .posRead = null, + .buffer = .initBuffer(@constCast(buffer)), + .unbuffered_writer = .{ + .context = undefined, + .vtable = &std.io.Writer.VTable.eof, }, }, + .unbuffered_reader = &.{ + .context = undefined, + .vtable = &std.io.Reader.VTable.eof, + }, }; } -pub fn deinit(br: *BufferedReader) void { - br.storage.deinit(); - br.* = undefined; +pub fn storageBuffer(br: *BufferedReader) []u8 { + assert(br.storage.unbuffered_writer.vtable == &std.io.Writer.VTable.eof); + assert(br.unbuffered_reader.vtable == &std.io.Reader.VTable.eof); + return br.storage.buffer.allocatedSlice(); } /// Although `BufferedReader` can easily satisfy the `Reader` interface, it's @@ -51,30 +62,31 @@ pub fn reader(br: *BufferedReader) Reader { }; } -fn passthru_streamRead(ctx: *anyopaque, bw: *BufferedWriter, limit: Reader.Limit) anyerror!Reader.Status { +fn passthru_streamRead(ctx: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) anyerror!Reader.RwResult { const br: *BufferedReader = @alignCast(@ptrCast(ctx)); const buffer = br.storage.buffer.items; const buffered = buffer[br.seek..]; const limited = buffered[0..limit.min(buffered.len)]; if (limited.len > 0) { - const n = try bw.writeSplat(limited, 1); - br.seek += n; + const result = bw.writeSplat(limited, 1); + br.seek += result.len; return .{ - .end = false, - .len = @intCast(n), + .len = result.len, + .write_err = result.err, + .write_end = result.end, }; } return br.unbuffered_reader.streamRead(bw, limit); } -fn passthru_streamReadVec(ctx: *anyopaque, data: []const []u8) anyerror!Reader.Status { +fn passthru_streamReadVec(ctx: ?*anyopaque, data: []const []u8) anyerror!Reader.Status { const br: *BufferedReader = @alignCast(@ptrCast(ctx)); _ = br; _ = data; @panic("TODO"); } -fn passthru_posRead(ctx: *anyopaque, bw: *BufferedWriter, limit: Reader.Limit, off: u64) anyerror!Reader.Status { +fn passthru_posRead(ctx: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit, off: u64) anyerror!Reader.Status { const br: *BufferedReader = @alignCast(@ptrCast(ctx)); const buffer = br.storage.buffer.items; if (off < buffer.len) { @@ -84,7 +96,7 @@ fn passthru_posRead(ctx: *anyopaque, bw: *BufferedWriter, limit: Reader.Limit, o return br.unbuffered_reader.posRead(bw, limit, off - buffer.len); } -fn passthru_posReadVec(ctx: *anyopaque, data: []const []u8, off: u64) anyerror!Reader.Status { +fn passthru_posReadVec(ctx: ?*anyopaque, data: []const []u8, off: u64) anyerror!Reader.Status { const br: *BufferedReader = @alignCast(@ptrCast(ctx)); _ = br; _ = data; @@ -155,8 +167,24 @@ pub fn takeArray(br: *BufferedReader, comptime n: usize) anyerror!*[n]u8 { /// /// See also: /// * `toss` -/// * `discardAll` +/// * `discardUntilEnd` +/// * `discardUpTo` pub fn discard(br: *BufferedReader, n: usize) anyerror!void { + if ((try discardUpTo(br, n)) != n) return error.EndOfStream; +} + +/// Skips the next `n` bytes from the stream, advancing the seek position. +/// +/// Unlike `toss` which is infallible, in this function `n` can be any amount. +/// +/// Returns the number of bytes discarded, which is less than `n` if and only +/// if the stream reached the end. +/// +/// See also: +/// * `discard` +/// * `toss` +/// * `discardUntilEnd` +pub fn discardUpTo(br: *BufferedReader, n: usize) anyerror!usize { const list = &br.storage.buffer; var remaining = n; while (remaining > 0) { @@ -168,19 +196,22 @@ pub fn discard(br: *BufferedReader, n: usize) anyerror!void { remaining -= (list.items.len - br.seek); list.items.len = 0; br.seek = 0; - const status = try br.unbuffered_reader.streamRead(&br.storage, .none); + const result = try br.unbuffered_reader.streamRead(&br.storage, .none); + result.write_err catch unreachable; + try result.read_err; + assert(result.len == list.items.len); if (remaining <= list.items.len) continue; - if (status.end) return error.EndOfStream; + if (result.end) return n - remaining; } } /// Reads the stream until the end, ignoring all the data. /// Returns the number of bytes discarded. -pub fn discardAll(br: *BufferedReader) anyerror!usize { +pub fn discardUntilEnd(br: *BufferedReader) anyerror!usize { const list = &br.storage.buffer; var total: usize = list.items.len; list.items.len = 0; - total += try br.unbuffered_reader.discardAll(); + total += try br.unbuffered_reader.discardUntilEnd(); return total; } @@ -224,6 +255,15 @@ pub fn read(br: *BufferedReader, buffer: []u8) anyerror!void { } } +/// Returns the number of bytes read. If the number read is smaller than `buffer.len`, it +/// means the stream reached the end. Reaching the end of a stream is not an error +/// condition. +pub fn partialRead(br: *BufferedReader, buffer: []u8) anyerror!usize { + _ = br; + _ = buffer; + @panic("TODO"); +} + /// Returns a slice of the next bytes of buffered data from the stream until /// `delimiter` is found, advancing the seek position. /// @@ -463,6 +503,95 @@ pub fn takeEnum(br: *BufferedReader, comptime Enum: type, endian: std.builtin.En return std.meta.intToEnum(Enum, int); } +/// Read a single unsigned LEB128 value from the given reader as type T, +/// or error.Overflow if the value cannot fit. +pub fn takeUleb128(br: *std.io.BufferedReader, comptime T: type) anyerror!T { + const U = if (@typeInfo(T).int.bits < 8) u8 else T; + const ShiftT = std.math.Log2Int(U); + + const max_group = (@typeInfo(U).int.bits + 6) / 7; + + var value: U = 0; + var group: ShiftT = 0; + + while (group < max_group) : (group += 1) { + const byte = try br.takeByte(); + + const ov = @shlWithOverflow(@as(U, byte & 0x7f), group * 7); + if (ov[1] != 0) return error.Overflow; + + value |= ov[0]; + if (byte & 0x80 == 0) break; + } else { + return error.Overflow; + } + + // only applies in the case that we extended to u8 + if (U != T) { + if (value > std.math.maxInt(T)) return error.Overflow; + } + + return @truncate(value); +} + +/// Read a single signed LEB128 value from the given reader as type T, +/// or `error.Overflow` if the value cannot fit. +pub fn takeIleb128(br: *std.io.BufferedReader, comptime T: type) anyerror!T { + const S = if (@typeInfo(T).int.bits < 8) i8 else T; + const U = std.meta.Int(.unsigned, @typeInfo(S).int.bits); + const ShiftU = std.math.Log2Int(U); + + const max_group = (@typeInfo(U).int.bits + 6) / 7; + + var value = @as(U, 0); + var group = @as(ShiftU, 0); + + while (group < max_group) : (group += 1) { + const byte = try br.takeByte(); + + const shift = group * 7; + const ov = @shlWithOverflow(@as(U, byte & 0x7f), shift); + if (ov[1] != 0) { + // Overflow is ok so long as the sign bit is set and this is the last byte + if (byte & 0x80 != 0) return error.Overflow; + if (@as(S, @bitCast(ov[0])) >= 0) return error.Overflow; + + // and all the overflowed bits are 1 + const remaining_shift = @as(u3, @intCast(@typeInfo(U).int.bits - @as(u16, shift))); + const remaining_bits = @as(i8, @bitCast(byte | 0x80)) >> remaining_shift; + if (remaining_bits != -1) return error.Overflow; + } else { + // If we don't overflow and this is the last byte and the number being decoded + // is negative, check that the remaining bits are 1 + if ((byte & 0x80 == 0) and (@as(S, @bitCast(ov[0])) < 0)) { + const remaining_shift = @as(u3, @intCast(@typeInfo(U).int.bits - @as(u16, shift))); + const remaining_bits = @as(i8, @bitCast(byte | 0x80)) >> remaining_shift; + if (remaining_bits != -1) return error.Overflow; + } + } + + value |= ov[0]; + if (byte & 0x80 == 0) { + const needs_sign_ext = group + 1 < max_group; + if (byte & 0x40 != 0 and needs_sign_ext) { + const ones = @as(S, -1); + value |= @as(U, @bitCast(ones)) << (shift + 7); + } + break; + } + } else { + return error.Overflow; + } + + const result = @as(S, @bitCast(value)); + // Only applies if we extended to i8 + if (S != T) { + if (result > std.math.maxInt(T) or result < std.math.minInt(T)) return error.Overflow; + } + + return @truncate(result); +} + test initFixed { var br: BufferedReader = undefined; br.initFixed("a\x02"); @@ -501,7 +630,7 @@ test discard { try testing.expectError(error.EndOfStream, br.discard(1)); } -test discardAll { +test discardUntilEnd { return error.Unimplemented; } @@ -576,3 +705,11 @@ test takeStructEndian { test takeEnum { return error.Unimplemented; } + +test takeUleb128 { + return error.Unimplemented; +} + +test takeIleb128 { + return error.Unimplemented; +} diff --git a/lib/std/io/BufferedWriter.zig b/lib/std/io/BufferedWriter.zig index 5a22f8c0d1..94a9b5db62 100644 --- a/lib/std/io/BufferedWriter.zig +++ b/lib/std/io/BufferedWriter.zig @@ -43,7 +43,7 @@ const fixed_vtable: Writer.VTable = .{ }; /// Replaces the `BufferedWriter` with a new one that writes to `buffer` and -/// returns `error.NoSpaceLeft` when it is full. +/// then ends when it is full. pub fn initFixed(bw: *BufferedWriter, buffer: []u8) void { bw.* = .{ .unbuffered_writer = .{ @@ -77,6 +77,36 @@ pub fn unusedCapacitySlice(bw: *const BufferedWriter) []u8 { return bw.buffer.unusedCapacitySlice(); } +pub fn writableSlice(bw: *BufferedWriter, minimum_length: usize) anyerror![]u8 { + const list = &bw.buffer; + assert(list.capacity >= minimum_length); + const cap_slice = list.unusedCapacitySlice(); + if (cap_slice.len >= minimum_length) { + @branchHint(.likely); + return cap_slice; + } + const buffer = list.items; + const result = bw.unbuffered_writer.write(buffer); + if (result.len == buffer.len) { + @branchHint(.likely); + list.items.len = 0; + try result.err; + return list.unusedCapacitySlice(); + } + if (result.len > 0) { + const remainder = buffer[result.len..]; + std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); + list.items.len = remainder.len; + } + try result.err; + return list.unusedCapacitySlice(); +} + +/// After calling `writableSlice`, this function tracks how many bytes were written to it. +pub fn advance(bw: *BufferedWriter, n: usize) void { + bw.items.len += n; +} + /// The `data` parameter is mutable because this function needs to mutate the /// fields in order to handle partial writes from `Writer.VTable.writev`. pub fn writevAll(bw: *BufferedWriter, data: [][]const u8) anyerror!void { @@ -92,15 +122,15 @@ pub fn writevAll(bw: *BufferedWriter, data: [][]const u8) anyerror!void { } } -pub fn writeSplat(bw: *BufferedWriter, data: []const []const u8, splat: usize) anyerror!usize { +pub fn writeSplat(bw: *BufferedWriter, data: []const []const u8, splat: usize) Writer.Result { return passthru_writeSplat(bw, data, splat); } -pub fn writev(bw: *BufferedWriter, data: []const []const u8) anyerror!usize { +pub fn writev(bw: *BufferedWriter, data: []const []const u8) Writer.Result { return passthru_writeSplat(bw, data, 1); } -fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { +fn passthru_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Result { const bw: *BufferedWriter = @alignCast(@ptrCast(context)); const list = &bw.buffer; const buffer = list.allocatedSlice(); @@ -126,27 +156,45 @@ fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usi if (len >= remaining_data.len) { @branchHint(.likely); // Made it past the headers, so we can enable splatting. - const n = try bw.unbuffered_writer.writeSplat(send_buffers, splat); + const result = bw.unbuffered_writer.writeSplat(send_buffers, splat); + const n = result.len; if (n < end) { @branchHint(.unlikely); const remainder = buffer[n..end]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); list.items.len = remainder.len; - return end - start_end; + return .{ + .err = result.err, + .len = end - start_end, + .end = result.end, + }; } list.items.len = 0; - return n - start_end; + return .{ + .err = result.err, + .len = n - start_end, + .end = result.end, + }; } - const n = try bw.unbuffered_writer.writeSplat(send_buffers, 1); + const result = try bw.unbuffered_writer.writeSplat(send_buffers, 1); + const n = result.len; if (n < end) { @branchHint(.unlikely); const remainder = buffer[n..end]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); list.items.len = remainder.len; - return end - start_end; + return .{ + .err = result.err, + .len = end - start_end, + .end = result.end, + }; } list.items.len = 0; - return n - start_end; + return .{ + .err = result.err, + .len = n - start_end, + .end = result.end, + }; } const pattern = data[data.len - 1]; @@ -156,7 +204,7 @@ fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usi // It was added in the loop above; undo it here. end -= pattern.len; list.items.len = end; - return end - start_end; + return .{ .len = end - start_end }; } const remaining_splat = splat - 1; @@ -164,7 +212,7 @@ fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usi switch (pattern.len) { 0 => { list.items.len = end; - return end - start_end; + return .{ .len = end - start_end }; }, 1 => { const new_end = end + remaining_splat; @@ -172,20 +220,29 @@ fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usi @branchHint(.likely); @memset(buffer[end..new_end], pattern[0]); list.items.len = new_end; - return new_end - start_end; + return .{ .len = new_end - start_end }; } buffers[0] = buffer[0..end]; buffers[1] = pattern; - const n = try bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat); + const result = bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat); + const n = result.len; if (n < end) { @branchHint(.unlikely); const remainder = buffer[n..end]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); list.items.len = remainder.len; - return end - start_end; + return .{ + .err = result.err, + .len = end - start_end, + .end = result.end, + }; } list.items.len = 0; - return n - start_end; + return .{ + .err = result.err, + .len = n - start_end, + .end = result.end, + }; }, else => { const new_end = end + pattern.len * remaining_splat; @@ -195,46 +252,43 @@ fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usi @memcpy(buffer[end..][0..pattern.len], pattern); } list.items.len = new_end; - return new_end - start_end; + return .{ .len = new_end - start_end }; } buffers[0] = buffer[0..end]; buffers[1] = pattern; - const n = try bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat); + const result = bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat); + const n = result.len; if (n < end) { @branchHint(.unlikely); const remainder = buffer[n..end]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); list.items.len = remainder.len; - return end - start_end; + return .{ + .err = result.err, + .len = end - start_end, + .end = result.end, + }; } list.items.len = 0; - return n - start_end; + return .{ + .err = result.err, + .len = n - start_end, + .end = result.end, + }; }, } } -fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize { - const bw: *BufferedWriter = @alignCast(@ptrCast(context)); - const list = &bw.buffer; - // When this function is called it means the buffer got full, so it's time - // to return an error. However, we still need to make sure all of the - // available buffer has been used. - const first = data[0]; - const dest = list.unusedCapacitySlice(); - @memcpy(dest, first[0..dest.len]); - list.items.len = list.capacity; - return error.NoSpaceLeft; -} - /// When this function is called it means the buffer got full, so it's time /// to return an error. However, we still need to make sure all of the /// available buffer has been filled. -fn fixed_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { +fn fixed_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Result { const bw: *BufferedWriter = @alignCast(@ptrCast(context)); const list = &bw.buffer; + const start_len = list.items.len; for (data) |bytes| { const dest = list.unusedCapacitySlice(); - if (dest.len == 0) return error.NoSpaceLeft; + if (dest.len == 0) return .{ .len = list.items.len - start_len, .end = true }; const len = @min(bytes.len, dest.len); @memcpy(dest[0..len], bytes[0..len]); list.items.len += len; @@ -247,90 +301,153 @@ fn fixed_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) else => for (0..splat - 1) |i| @memcpy(dest[i * pattern.len ..][0..pattern.len], pattern), } list.items.len = list.capacity; - return error.NoSpaceLeft; + return .{ .len = list.items.len - start_len, .end = true }; } -pub fn write(bw: *BufferedWriter, bytes: []const u8) anyerror!usize { +pub fn write(bw: *BufferedWriter, bytes: []const u8) Writer.Result { const list = &bw.buffer; const buffer = list.allocatedSlice(); const end = list.items.len; const new_end = end + bytes.len; if (new_end > buffer.len) { var data: [2][]const u8 = .{ buffer[0..end], bytes }; - const n = try bw.unbuffered_writer.writev(&data); + const result = bw.unbuffered_writer.writev(&data); + const n = result.len; if (n < end) { @branchHint(.unlikely); const remainder = buffer[n..end]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); list.items.len = remainder.len; - return 0; + return .{ + .err = result.err, + .len = 0, + .end = result.end, + }; } list.items.len = 0; - return n - end; + return .{ + .err = result.err, + .len = n - end, + .end = result.end, + }; } @memcpy(buffer[end..new_end], bytes); list.items.len = new_end; return bytes.len; } -/// This function is provided by the `Writer`, however it is -/// duplicated here so that `bw` can be passed to `std.fmt.format` directly, -/// avoiding one indirect function call. pub fn writeAll(bw: *BufferedWriter, bytes: []const u8) anyerror!void { + if ((try writeUntilEnd(bw, bytes)) != bytes.len) return error.WriteStreamEnd; +} + +pub fn writeAllCount(bw: *BufferedWriter, bytes: []const u8) anyerror!usize { + try writeAll(bw, bytes); + return bytes.len; +} + +/// If the number returned is less than `bytes.len` it indicates end of stream. +pub fn writeUntilEnd(bw: *BufferedWriter, bytes: []const u8) anyerror!usize { var index: usize = 0; - while (index < bytes.len) index += try write(bw, bytes[index..]); + while (true) { + const result = write(bw, bytes[index..]); + try result.err; + index += result.len; + assert(index <= bytes.len); + if (index == bytes.len or result.end) return index; + } } pub fn print(bw: *BufferedWriter, comptime format: []const u8, args: anytype) anyerror!void { + _ = try std.fmt.format(bw, format, args); +} + +pub fn printCount(bw: *BufferedWriter, comptime format: []const u8, args: anytype) anyerror!usize { return std.fmt.format(bw, format, args); } pub fn writeByte(bw: *BufferedWriter, byte: u8) anyerror!void { + if ((try writeByteUntilEnd(bw, byte)) == 0) return error.WriteStreamEnd; +} + +pub fn writeByteCount(bw: *BufferedWriter, byte: u8) anyerror!usize { + try writeByte(bw, byte); + return 1; +} + +/// Returns 0 or 1 indicating how many bytes were written. +/// `0` means end of stream encountered. +pub fn writeByteUntilEnd(bw: *BufferedWriter, byte: u8) anyerror!usize { const list = &bw.buffer; const buffer = list.items; if (buffer.len < list.capacity) { @branchHint(.likely); buffer.ptr[buffer.len] = byte; list.items.len = buffer.len + 1; - return; + return 1; } var buffers: [2][]const u8 = .{ buffer, &.{byte} }; while (true) { - const n = try bw.unbuffered_writer.writev(&buffers); + const result = bw.unbuffered_writer.writev(&buffers); + try result.err; + const n = result.len; if (n == 0) { @branchHint(.unlikely); + if (result.end) return 0; continue; } else if (n >= buffer.len) { @branchHint(.likely); if (n > buffer.len) { @branchHint(.likely); list.items.len = 0; - return; + return 1; } else { buffer[0] = byte; list.items.len = 1; - return; + return 1; } } const remainder = buffer[n..]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); buffer[remainder.len] = byte; list.items.len = remainder.len + 1; - return; + return 1; } } /// Writes the same byte many times, performing the underlying write call as -/// many times as necessary. +/// many times as necessary, returning `error.WriteStreamEnd` if the byte +/// could not be repeated `n` times. pub fn splatByteAll(bw: *BufferedWriter, byte: u8, n: usize) anyerror!void { - var remaining: usize = n; - while (remaining > 0) remaining -= try splatByte(bw, byte, remaining); + if ((try splatByteUntilEnd(bw, byte, n)) != n) return error.WriteStreamEnd; +} + +/// Writes the same byte many times, performing the underlying write call as +/// many times as necessary, returning `error.WriteStreamEnd` if the byte +/// could not be repeated `n` times, or returning `n` on success. +pub fn splatByteAllCount(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize { + try splatByteAll(bw, byte, n); + return n; +} + +/// Writes the same byte many times, performing the underlying write call as +/// many times as necessary. +/// +/// If the number returned is less than `n` it indicates end of stream. +pub fn splatByteUntilEnd(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize { + var index: usize = 0; + while (true) { + const result = splatByte(bw, byte, n - index); + try result.err; + index += result.len; + assert(index <= n); + if (index == n or result.end) return index; + } } /// Writes the same byte many times, allowing short writes. /// -/// Does maximum of one underlying `Writer.VTable.writev`. -pub fn splatByte(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize { +/// Does maximum of one underlying `Writer.VTable.writeSplat`. +pub fn splatByte(bw: *BufferedWriter, byte: u8, n: usize) Writer.Result { return passthru_writeSplat(bw, &.{&.{byte}}, n); } @@ -389,7 +506,7 @@ pub fn writeFile( } fn passthru_writeFile( - context: *anyopaque, + context: ?*anyopaque, file: std.fs.File, offset: u64, len: Writer.FileLen, @@ -544,32 +661,34 @@ pub fn alignBuffer( width: usize, alignment: std.fmt.Alignment, fill: u8, -) anyerror!void { +) anyerror!usize { const padding = if (buffer.len < width) width - buffer.len else 0; if (padding == 0) { @branchHint(.likely); - return bw.writeAll(buffer); + return bw.writeAllCount(buffer); } + var n: usize = 0; switch (alignment) { .left => { - try bw.writeAll(buffer); - try bw.splatByteAll(fill, padding); + n += try bw.writeAllCount(buffer); + n += try bw.splatByteAllCount(fill, padding); }, .center => { const left_padding = padding / 2; const right_padding = (padding + 1) / 2; - try bw.splatByteAll(fill, left_padding); - try bw.writeAll(buffer); - try bw.splatByteAll(fill, right_padding); + n += try bw.splatByteAllCount(fill, left_padding); + n += try bw.writeAllCount(buffer); + n += try bw.splatByteAllCount(fill, right_padding); }, .right => { - try bw.splatByteAll(fill, padding); - try bw.writeAll(buffer); + n += try bw.splatByteAllCount(fill, padding); + n += try bw.writeAllCount(buffer); }, } + return n; } -pub fn alignBufferOptions(bw: *BufferedWriter, buffer: []const u8, options: std.fmt.Options) anyerror!void { +pub fn alignBufferOptions(bw: *BufferedWriter, buffer: []const u8, options: std.fmt.Options) anyerror!usize { return alignBuffer(bw, buffer, options.width orelse buffer.len, options.alignment, options.fill); } @@ -604,7 +723,7 @@ pub fn printValue( options: std.fmt.Options, value: anytype, max_depth: usize, -) anyerror!void { +) anyerror!usize { const T = @TypeOf(value); const actual_fmt = comptime if (std.mem.eql(u8, fmt, ANY)) defaultFormatString(T) @@ -619,13 +738,10 @@ pub fn printValue( if (std.meta.hasMethod(T, "format")) { if (fmt.len > 0 and fmt[0] == 'f') { - return value.format(fmt[1..], options, bw); - } else { - //@deprecated(); - // After 0.14.0 is tagged, uncomment this next line: - //@compileError("ambiguous format string; specify {f} to call format method, or {any} to skip it"); - //and then delete the `hasMethod` condition - return value.format(fmt, options, bw); + return value.format(bw, fmt[1..]); + } else if (fmt.len == 0) { + // after 0.15.0 is tagged, delete the hasMethod condition and this compile error + @compileError("ambiguous format string; specify {f} to call format method, or {any} to skip it"); } } @@ -662,92 +778,104 @@ pub fn printValue( }, .error_set => { if (actual_fmt.len > 0 and actual_fmt[0] == 's') { - return bw.writeAll(@errorName(value)); + return bw.writeAllCount(@errorName(value)); } else if (actual_fmt.len != 0) { invalidFmtError(fmt, value); } else { - try bw.writeAll("error."); - return bw.writeAll(@errorName(value)); + var n: usize = 0; + n += try bw.writeAllCount("error."); + n += try bw.writeAllCount(@errorName(value)); + return n; } }, - .@"enum" => |enumInfo| { - try bw.writeAll(@typeName(T)); - if (enumInfo.is_exhaustive) { + .@"enum" => |enum_info| { + var n: usize = 0; + n += try bw.writeAllCount(@typeName(T)); + if (enum_info.is_exhaustive) { if (actual_fmt.len != 0) invalidFmtError(fmt, value); - try bw.writeAll("."); - try bw.writeAll(@tagName(value)); - return; + n += try bw.writeAllCount("."); + n += try bw.writeAllCount(@tagName(value)); + return n; } // Use @tagName only if value is one of known fields - @setEvalBranchQuota(3 * enumInfo.fields.len); - inline for (enumInfo.fields) |enumField| { + @setEvalBranchQuota(3 * enum_info.fields.len); + inline for (enum_info.fields) |enumField| { if (@intFromEnum(value) == enumField.value) { - try bw.writeAll("."); - try bw.writeAll(@tagName(value)); + n += try bw.writeAllCount("."); + n += try bw.writeAllCount(@tagName(value)); return; } } - try bw.writeByte('('); - try printValue(bw, actual_fmt, options, @intFromEnum(value), max_depth); - try bw.writeByte(')'); + n += try bw.writeByteCount('('); + n += try printValue(bw, actual_fmt, options, @intFromEnum(value), max_depth); + n += try bw.writeByteCount(')'); + return n; }, .@"union" => |info| { if (actual_fmt.len != 0) invalidFmtError(fmt, value); - try bw.writeAll(@typeName(T)); + var n: usize = 0; + n += try bw.writeAllCount(@typeName(T)); if (max_depth == 0) { - return bw.writeAll("{ ... }"); + n += bw.writeAllCount("{ ... }"); + return n; } if (info.tag_type) |UnionTagType| { - try bw.writeAll("{ ."); - try bw.writeAll(@tagName(@as(UnionTagType, value))); - try bw.writeAll(" = "); + n += try bw.writeAllCount("{ ."); + n += try bw.writeAllCount(@tagName(@as(UnionTagType, value))); + n += try bw.writeAllCount(" = "); inline for (info.fields) |u_field| { if (value == @field(UnionTagType, u_field.name)) { - try printValue(bw, ANY, options, @field(value, u_field.name), max_depth - 1); + n += try printValue(bw, ANY, options, @field(value, u_field.name), max_depth - 1); } } - try bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); } else { - try bw.writeByte('@'); - try bw.printIntOptions(@intFromPtr(&value), 16, .lower); + n += try bw.writeByte('@'); + n += try bw.printIntOptions(@intFromPtr(&value), 16, .lower); } + return n; }, .@"struct" => |info| { if (actual_fmt.len != 0) invalidFmtError(fmt, value); + var n: usize = 0; if (info.is_tuple) { // Skip the type and field names when formatting tuples. if (max_depth == 0) { - return bw.writeAll("{ ... }"); + n += try bw.writeAllCount("{ ... }"); + return n; } - try bw.writeAll("{"); + n += try bw.writeAllCount("{"); inline for (info.fields, 0..) |f, i| { if (i == 0) { - try bw.writeAll(" "); + n += try bw.writeAllCount(" "); } else { - try bw.writeAll(", "); + n += try bw.writeAllCount(", "); } - try printValue(bw, ANY, options, @field(value, f.name), max_depth - 1); + n += try printValue(bw, ANY, options, @field(value, f.name), max_depth - 1); } - return bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); + return n; } - try bw.writeAll(@typeName(T)); + n += try bw.writeAllCount(@typeName(T)); if (max_depth == 0) { - return bw.writeAll("{ ... }"); + n += try bw.writeAllCount("{ ... }"); + return n; } - try bw.writeAll("{"); + n += try bw.writeAllCount("{"); inline for (info.fields, 0..) |f, i| { if (i == 0) { - try bw.writeAll(" ."); + n += try bw.writeAllCount(" ."); } else { - try bw.writeAll(", ."); + n += try bw.writeAllCount(", ."); } - try bw.writeAll(f.name); - try bw.writeAll(" = "); - try printValue(bw, ANY, options, @field(value, f.name), max_depth - 1); + n += try bw.writeAllCount(f.name); + n += try bw.writeAllCount(" = "); + n += try printValue(bw, ANY, options, @field(value, f.name), max_depth - 1); } - try bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); + return n; }, .pointer => |ptr_info| switch (ptr_info.size) { .one => switch (@typeInfo(ptr_info.child)) { @@ -756,8 +884,10 @@ pub fn printValue( }, else => { var buffers: [2][]const u8 = .{ @typeName(ptr_info.child), "@" }; - try writevAll(bw, &buffers); - try printIntOptions(bw, @intFromPtr(value), 16, .lower, options); + var n: usize = 0; + n += try writevAll(bw, &buffers); + n += try printIntOptions(bw, @intFromPtr(value), 16, .lower, options); + return n; }, }, .many, .c => { @@ -775,7 +905,7 @@ pub fn printValue( if (actual_fmt.len == 0) @compileError("cannot format slice without a specifier (i.e. {s}, {x}, {b64}, or {any})"); if (max_depth == 0) { - return bw.writeAll("{ ... }"); + return bw.writeAllCount("{ ... }"); } if (ptr_info.child == u8) switch (actual_fmt.len) { 1 => switch (actual_fmt[0]) { @@ -789,21 +919,23 @@ pub fn printValue( }, else => {}, }; - try bw.writeAll("{ "); + var n: usize = 0; + n += try bw.writeAllCount("{ "); for (value, 0..) |elem, i| { - try printValue(bw, actual_fmt, options, elem, max_depth - 1); + n += try printValue(bw, actual_fmt, options, elem, max_depth - 1); if (i != value.len - 1) { - try bw.writeAll(", "); + n += try bw.writeAllCount(", "); } } - try bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); + return n; }, }, .array => |info| { if (actual_fmt.len == 0) @compileError("cannot format array without a specifier (i.e. {s} or {any})"); if (max_depth == 0) { - return bw.writeAll("{ ... }"); + return bw.writeAllCount("{ ... }"); } if (info.child == u8) { if (actual_fmt[0] == 's') { @@ -814,28 +946,32 @@ pub fn printValue( return printHex(bw, &value, .upper); } } - try bw.writeAll("{ "); + var n: usize = 0; + n += try bw.writeAllCount("{ "); for (value, 0..) |elem, i| { - try printValue(bw, actual_fmt, options, elem, max_depth - 1); + n += try printValue(bw, actual_fmt, options, elem, max_depth - 1); if (i < value.len - 1) { - try bw.writeAll(", "); + n += try bw.writeAllCount(", "); } } - try bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); + return n; }, .vector => |info| { if (max_depth == 0) { - return bw.writeAll("{ ... }"); + return bw.writeAllCount("{ ... }"); } - try bw.writeAll("{ "); + var n: usize = 0; + n += try bw.writeAllCount("{ "); var i: usize = 0; while (i < info.len) : (i += 1) { - try printValue(bw, actual_fmt, options, value[i], max_depth - 1); + n += try printValue(bw, actual_fmt, options, value[i], max_depth - 1); if (i < info.len - 1) { - try bw.writeAll(", "); + n += try bw.writeAllCount(", "); } } - try bw.writeAll(" }"); + n += try bw.writeAllCount(" }"); + return n; }, .@"fn" => @compileError("unable to format function body type, use '*const " ++ @typeName(T) ++ "' for a function pointer type"), .type => { @@ -860,7 +996,7 @@ pub fn printInt( comptime fmt: []const u8, options: std.fmt.Options, value: anytype, -) anyerror!void { +) anyerror!usize { const int_value = if (@TypeOf(value) == comptime_int) blk: { const Int = std.math.IntFittingRange(value, value); break :blk @as(Int, value); @@ -904,15 +1040,15 @@ pub fn printInt( comptime unreachable; } -pub fn printAsciiChar(bw: *BufferedWriter, c: u8, options: std.fmt.Options) anyerror!void { +pub fn printAsciiChar(bw: *BufferedWriter, c: u8, options: std.fmt.Options) anyerror!usize { return alignBufferOptions(bw, @as(*const [1]u8, &c), options); } -pub fn printAscii(bw: *BufferedWriter, bytes: []const u8, options: std.fmt.Options) anyerror!void { +pub fn printAscii(bw: *BufferedWriter, bytes: []const u8, options: std.fmt.Options) anyerror!usize { return alignBufferOptions(bw, bytes, options); } -pub fn printUnicodeCodepoint(bw: *BufferedWriter, c: u21, options: std.fmt.Options) anyerror!void { +pub fn printUnicodeCodepoint(bw: *BufferedWriter, c: u21, options: std.fmt.Options) anyerror!usize { var buf: [4]u8 = undefined; const len = try std.unicode.utf8Encode(c, &buf); return alignBufferOptions(bw, buf[0..len], options); @@ -924,7 +1060,7 @@ pub fn printIntOptions( base: u8, case: std.fmt.Case, options: std.fmt.Options, -) anyerror!void { +) anyerror!usize { assert(base >= 2); const int_value = if (@TypeOf(value) == comptime_int) blk: { @@ -991,7 +1127,7 @@ pub fn printFloat( comptime fmt: []const u8, options: std.fmt.Options, value: anytype, -) anyerror!void { +) anyerror!usize { var buf: [std.fmt.float.bufferSize(.decimal, f64)]u8 = undefined; if (fmt.len > 1) invalidFmtError(fmt, value); @@ -1279,7 +1415,7 @@ pub fn printDuration(bw: *BufferedWriter, nanoseconds: anytype, options: std.fmt return alignBufferOptions(bw, sub_bw.getWritten(), options); } -pub fn printHex(bw: *BufferedWriter, bytes: []const u8, case: std.fmt.Case) anyerror!void { +pub fn printHex(bw: *BufferedWriter, bytes: []const u8, case: std.fmt.Case) anyerror!usize { const charset = switch (case) { .upper => "0123456789ABCDEF", .lower => "0123456789abcdef", @@ -1288,12 +1424,68 @@ pub fn printHex(bw: *BufferedWriter, bytes: []const u8, case: std.fmt.Case) anye try writeByte(bw, charset[c >> 4]); try writeByte(bw, charset[c & 15]); } + return bytes.len * 2; } -pub fn printBase64(bw: *BufferedWriter, bytes: []const u8) anyerror!void { +pub fn printBase64(bw: *BufferedWriter, bytes: []const u8) anyerror!usize { var chunker = std.mem.window(u8, bytes, 3, 3); var temp: [5]u8 = undefined; - while (chunker.next()) |chunk| try bw.writeAll(std.base64.standard.Encoder.encode(&temp, chunk)); + var n: usize = 0; + while (chunker.next()) |chunk| { + n += try bw.writeAllCount(std.base64.standard.Encoder.encode(&temp, chunk)); + } + return n; +} + +/// Write a single unsigned integer as unsigned LEB128 to the given writer. +pub fn writeUleb128(bw: *std.io.BufferedWriter, arg: anytype) anyerror!usize { + const Arg = @TypeOf(arg); + const Int = switch (Arg) { + comptime_int => std.math.IntFittingRange(arg, arg), + else => Arg, + }; + const Value = if (@typeInfo(Int).int.bits < 8) u8 else Int; + var value: Value = arg; + var n: usize = 0; + + while (true) { + const byte: u8 = @truncate(value & 0x7f); + value >>= 7; + if (value == 0) { + try bw.writeByte(byte); + return n + 1; + } else { + try bw.writeByte(byte | 0x80); + n += 1; + } + } +} + +/// Write a single signed integer as signed LEB128 to the given writer. +pub fn writeIleb128(bw: *std.io.BufferedWriter, arg: anytype) anyerror!usize { + const Arg = @TypeOf(arg); + const Int = switch (Arg) { + comptime_int => std.math.IntFittingRange(-@abs(arg), @abs(arg)), + else => Arg, + }; + const Signed = if (@typeInfo(Int).int.bits < 8) i8 else Int; + const Unsigned = std.meta.Int(.unsigned, @typeInfo(Signed).int.bits); + var value: Signed = arg; + var n: usize = 0; + + while (true) { + const unsigned: Unsigned = @bitCast(value); + const byte: u8 = @truncate(unsigned); + value >>= 6; + if (value == -1 or value == 0) { + try bw.writeByte(byte & 0x7F); + return n + 1; + } else { + value >>= 1; + try bw.writeByte(byte | 0x80); + n += 1; + } + } } test "formatValue max_depth" { @@ -1590,15 +1782,15 @@ test "fixed output" { try bw.writeAll("world"); try testing.expect(std.mem.eql(u8, bw.getWritten(), "Helloworld")); - try testing.expectError(error.NoSpaceLeft, bw.writeAll("!")); + try testing.expectError(error.WriteStreamEnd, bw.writeAll("!")); try testing.expect(std.mem.eql(u8, bw.getWritten(), "Helloworld")); bw.reset(); try testing.expect(bw.getWritten().len == 0); - try testing.expectError(error.NoSpaceLeft, bw.writeAll("Hello world!")); + try testing.expectError(error.WriteStreamEnd, bw.writeAll("Hello world!")); try testing.expect(std.mem.eql(u8, bw.getWritten(), "Hello worl")); try bw.seekTo((try bw.getEndPos()) + 1); - try testing.expectError(error.NoSpaceLeft, bw.writeAll("H")); + try testing.expectError(error.WriteStreamEnd, bw.writeAll("H")); } diff --git a/lib/std/io/CountingReader.zig b/lib/std/io/CountingReader.zig deleted file mode 100644 index 4be9d123e4..0000000000 --- a/lib/std/io/CountingReader.zig +++ /dev/null @@ -1,29 +0,0 @@ -//! A Reader that counts how many bytes has been read from it. - -const std = @import("../std.zig"); -const CountingReader = @This(); - -child_reader: std.io.Reader, -bytes_read: u64 = 0, - -pub fn read(self: *@This(), buf: []u8) anyerror!usize { - const amt = try self.child_reader.read(buf); - self.bytes_read += amt; - return amt; -} - -pub fn reader(self: *@This()) std.io.Reader { - return .{ .context = self }; -} - -test CountingReader { - const bytes = "yay" ** 20; - var fbs: std.io.BufferedReader = undefined; - fbs.initFixed(bytes); - var counting_stream: CountingReader = .{ .child_reader = fbs.reader() }; - var stream = counting_stream.reader().unbuffered(); - while (stream.readByte()) |_| {} else |err| { - try std.testing.expectError(error.EndOfStream, err); - } - try std.testing.expect(counting_stream.bytes_read == bytes.len); -} diff --git a/lib/std/io/CountingWriter.zig b/lib/std/io/CountingWriter.zig deleted file mode 100644 index f421b0d217..0000000000 --- a/lib/std/io/CountingWriter.zig +++ /dev/null @@ -1,52 +0,0 @@ -//! TODO make this more like AllocatingWriter, managing the state of -//! BufferedWriter both as the output and the input, but with only -//! one buffer. -const std = @import("../std.zig"); -const CountingWriter = @This(); -const assert = std.debug.assert; -const native_endian = @import("builtin").target.cpu.arch.endian(); -const Writer = std.io.Writer; -const testing = std.testing; - -/// Underlying stream to passthrough bytes to. -child_writer: Writer, -bytes_written: u64 = 0, - -pub fn writer(cw: *CountingWriter) Writer { - return .{ - .context = cw, - .vtable = &.{ - .writeSplat = passthru_writeSplat, - .writeFile = passthru_writeFile, - }, - }; -} - -fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { - const cw: *CountingWriter = @alignCast(@ptrCast(context)); - const n = try cw.child_writer.writeSplat(data, splat); - cw.bytes_written += n; - return n; -} - -fn passthru_writeFile( - context: *anyopaque, - file: std.fs.File, - offset: u64, - len: Writer.FileLen, - headers_and_trailers: []const []const u8, - headers_len: usize, -) anyerror!usize { - const cw: *CountingWriter = @alignCast(@ptrCast(context)); - const n = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len); - cw.bytes_written += n; - return n; -} - -test CountingWriter { - var cw: CountingWriter = .{ .child_writer = std.io.null_writer }; - var bw = cw.writer().unbuffered(); - const bytes = "yay"; - try bw.writeAll(bytes); - try testing.expect(cw.bytes_written == bytes.len); -} diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index cacf0604da..2177e917b3 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -19,8 +19,8 @@ pub const VTable = struct { /// /// If this is `null` it is equivalent to always returning /// `error.Unseekable`. - posRead: ?*const fn (ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit, offset: u64) Result, - posReadVec: ?*const fn (ctx: ?*anyopaque, data: []const []u8, offset: u64) VecResult, + posRead: ?*const fn (ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit, offset: u64) RwResult, + posReadVec: ?*const fn (ctx: ?*anyopaque, data: []const []u8, offset: u64) Result, /// Writes bytes from the internally tracked stream position to `bw`, or /// returns `error.Unstreamable`, indicating `posRead` should be used @@ -37,38 +37,34 @@ pub const VTable = struct { /// /// If this is `null` it is equivalent to always returning /// `error.Unstreamable`. - streamRead: ?*const fn (ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit) Result, - streamReadVec: ?*const fn (ctx: ?*anyopaque, data: []const []u8) VecResult, + streamRead: ?*const fn (ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit) RwResult, + streamReadVec: ?*const fn (ctx: ?*anyopaque, data: []const []u8) Result, + + pub const eof: VTable = .{ + .posRead = eof_posRead, + .posReadVec = eof_posReadVec, + .streamRead = eof_streamRead, + .streamReadVec = eof_streamReadVec, + }; }; -pub const Len = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(usize) - 1 } }); +pub const Result = std.io.Writer.Result; -pub const VecResult = struct { - /// Even when a failure occurs, `Effect.written` may be nonzero, and - /// `Effect.end` may be true. - failure: anyerror!void, - effect: VecEffect, -}; - -pub const Result = struct { - /// Even when a failure occurs, `Effect.written` may be nonzero, and - /// `Effect.end` may be true. - failure: anyerror!void, - write_effect: Effect, - read_effect: Effect, -}; - -pub const Effect = packed struct(usize) { - /// Number of bytes that were read from the reader or written to the - /// writer. - len: Len, - /// Indicates end of stream. - end: bool, +pub const RwResult = struct { + len: usize = 0, + read_err: anyerror!void = {}, + write_err: anyerror!void = {}, + read_end: bool = false, + write_end: bool = false, }; pub const Limit = enum(usize) { none = std.math.maxInt(usize), _, + + pub fn min(l: Limit, int: usize) usize { + return @min(int, @intFromEnum(l)); + } }; /// Returns total number of bytes written to `w`. @@ -133,25 +129,11 @@ pub fn streamReadAlloc(r: Reader, gpa: std.mem.Allocator, max_size: usize) anyer /// Reads the stream until the end, ignoring all the data. /// Returns the number of bytes discarded. -pub fn discardAll(r: Reader) anyerror!usize { +pub fn discardUntilEnd(r: Reader) anyerror!usize { var bw = std.io.null_writer.unbuffered(); return streamReadAll(r, &bw); } -pub fn buffered(r: Reader, buffer: []u8) std.io.BufferedReader { - return .{ - .reader = r, - .buffered_writer = .{ - .buffer = buffer, - .mode = .fixed, - }, - }; -} - -pub fn unbuffered(r: Reader) std.io.BufferedReader { - return buffered(r, &.{}); -} - pub fn allocating(r: Reader, gpa: std.mem.Allocator) std.io.BufferedReader { return .{ .reader = r, @@ -189,3 +171,31 @@ test "when the backing reader provides one byte at a time" { defer std.testing.allocator.free(res); try std.testing.expectEqualStrings(str, res); } + +fn eof_posRead(ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit, offset: u64) RwResult { + _ = ctx; + _ = bw; + _ = limit; + _ = offset; + return .{ .end = true }; +} + +fn eof_posReadVec(ctx: ?*anyopaque, data: []const []u8, offset: u64) Result { + _ = ctx; + _ = data; + _ = offset; + return .{ .end = true }; +} + +fn eof_streamRead(ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Limit) RwResult { + _ = ctx; + _ = bw; + _ = limit; + return .{ .end = true }; +} + +fn eof_streamReadVec(ctx: ?*anyopaque, data: []const []u8) Result { + _ = ctx; + _ = data; + return .{ .end = true }; +} diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index 7882d3518b..eff9962927 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -2,7 +2,7 @@ const std = @import("../std.zig"); const assert = std.debug.assert; const Writer = @This(); -context: *anyopaque, +context: ?*anyopaque, vtable: *const VTable, pub const VTable = struct { @@ -17,7 +17,7 @@ pub const VTable = struct { /// Number of bytes returned may be zero, which does not mean /// end-of-stream. A subsequent call may return nonzero, or may signal end /// of stream via an error. - writeSplat: *const fn (ctx: *anyopaque, data: []const []const u8, splat: usize) Result, + writeSplat: *const fn (ctx: ?*anyopaque, data: []const []const u8, splat: usize) Result, /// Writes contents from an open file. `headers` are written first, then `len` /// bytes of `file` starting from `offset`, then `trailers`. @@ -29,7 +29,7 @@ pub const VTable = struct { /// end-of-stream. A subsequent call may return nonzero, or may signal end /// of stream via an error. writeFile: *const fn ( - ctx: *anyopaque, + ctx: ?*anyopaque, file: std.fs.File, offset: Offset, /// When zero, it means copy until the end of the file is reached. @@ -39,25 +39,26 @@ pub const VTable = struct { headers_and_trailers: []const []const u8, headers_len: usize, ) Result, -}; -pub const Len = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(usize) - 1 } }); + pub const eof: VTable = .{ + .writeSplat = eof_writeSplat, + .writeFile = eof_writeFile, + }; +}; pub const Result = struct { - /// Even when a failure occurs, `Effect.written` may be nonzero, and - /// `Effect.end` may be true. - failure: anyerror!void, - effect: Effect, -}; - -pub const Effect = packed struct(usize) { - /// Number of bytes that were written to `writer`. - len: Len, + /// Even when a failure occurs, `len` may be nonzero, and `end` may be + /// true. + err: anyerror!void = {}, + /// Number of bytes that were transferred. When an error occurs, ideally + /// this will be zero, but may not always be the case. + len: usize = 0, /// Indicates end of stream. - end: bool, + end: bool = false, }; pub const Offset = enum(u64) { + /// Indicates to read the file as a stream. none = std.math.maxInt(u64), _, @@ -66,6 +67,11 @@ pub const Offset = enum(u64) { assert(result != .none); return result; } + + pub fn toInt(o: Offset) ?u64 { + if (o == .none) return null; + return @intFromEnum(o); + } }; pub const FileLen = enum(u64) { @@ -84,11 +90,11 @@ pub const FileLen = enum(u64) { } }; -pub fn writev(w: Writer, data: []const []const u8) anyerror!usize { +pub fn writev(w: Writer, data: []const []const u8) Result { return w.vtable.writeSplat(w.context, data, 1); } -pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) anyerror!usize { +pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) Result { return w.vtable.writeSplat(w.context, data, splat); } @@ -99,25 +105,25 @@ pub fn writeFile( len: FileLen, headers_and_trailers: []const []const u8, headers_len: usize, -) anyerror!usize { +) Result { return w.vtable.writeFile(w.context, file, offset, len, headers_and_trailers, headers_len); } pub fn unimplemented_writeFile( - context: *anyopaque, + context: ?*anyopaque, file: std.fs.File, offset: u64, len: FileLen, headers_and_trailers: []const []const u8, headers_len: usize, -) anyerror!usize { +) Result { _ = context; _ = file; _ = offset; _ = len; _ = headers_and_trailers; _ = headers_len; - return error.Unimplemented; + return .{ .err = error.Unimplemented }; } pub fn buffered(w: Writer, buffer: []u8) std.io.BufferedWriter { @@ -130,3 +136,74 @@ pub fn buffered(w: Writer, buffer: []u8) std.io.BufferedWriter { pub fn unbuffered(w: Writer) std.io.BufferedWriter { return buffered(w, &.{}); } + +/// A `Writer` that discards all data. +pub const @"null": Writer = .{ + .context = undefined, + .vtable = &.{ + .writeSplat = null_writeSplat, + .writeFile = null_writeFile, + }, +}; + +fn null_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Result { + _ = context; + const headers = data[0 .. data.len - 1]; + const pattern = data[headers.len..]; + var written: usize = pattern.len * splat; + for (headers) |bytes| written += bytes.len; + return .{ .len = written }; +} + +fn null_writeFile( + context: ?*anyopaque, + file: std.fs.File, + offset: Offset, + len: FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, +) Result { + _ = context; + var n: usize = 0; + if (len == .entire_file) { + const headers = headers_and_trailers[0..headers_len]; + for (headers) |bytes| n += bytes.len; + if (offset.toInt()) |off| { + const stat = file.stat() catch |err| return .{ .err = err, .len = n }; + n += stat.size - off; + for (headers_and_trailers[headers_len..]) |bytes| n += bytes.len; + return .{ .len = n }; + } + @panic("TODO stream from file until eof, counting"); + } + for (headers_and_trailers) |bytes| n += bytes.len; + return .{ .len = len.int() + n }; +} + +test @"null" { + try @"null".writeAll("yay"); +} + +fn eof_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Result { + _ = context; + _ = data; + _ = splat; + return .{ .end = true }; +} + +fn eof_writeFile( + context: ?*anyopaque, + file: std.fs.File, + offset: u64, + len: FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, +) Result { + _ = context; + _ = file; + _ = offset; + _ = len; + _ = headers_and_trailers; + _ = headers_len; + return .{ .end = true }; +} diff --git a/lib/std/leb128.zig b/lib/std/leb128.zig index dd630b1d29..5e48fa107f 100644 --- a/lib/std/leb128.zig +++ b/lib/std/leb128.zig @@ -2,151 +2,6 @@ const builtin = @import("builtin"); const std = @import("std"); const testing = std.testing; -/// Read a single unsigned LEB128 value from the given reader as type T, -/// or error.Overflow if the value cannot fit. -pub fn readUleb128(comptime T: type, reader: anytype) !T { - const U = if (@typeInfo(T).int.bits < 8) u8 else T; - const ShiftT = std.math.Log2Int(U); - - const max_group = (@typeInfo(U).int.bits + 6) / 7; - - var value: U = 0; - var group: ShiftT = 0; - - while (group < max_group) : (group += 1) { - const byte = try reader.readByte(); - - const ov = @shlWithOverflow(@as(U, byte & 0x7f), group * 7); - if (ov[1] != 0) return error.Overflow; - - value |= ov[0]; - if (byte & 0x80 == 0) break; - } else { - return error.Overflow; - } - - // only applies in the case that we extended to u8 - if (U != T) { - if (value > std.math.maxInt(T)) return error.Overflow; - } - - return @as(T, @truncate(value)); -} - -/// Deprecated: use `readUleb128` -pub const readULEB128 = readUleb128; - -/// Write a single unsigned integer as unsigned LEB128 to the given writer. -pub fn writeUleb128(writer: anytype, arg: anytype) !void { - const Arg = @TypeOf(arg); - const Int = switch (Arg) { - comptime_int => std.math.IntFittingRange(arg, arg), - else => Arg, - }; - const Value = if (@typeInfo(Int).int.bits < 8) u8 else Int; - var value: Value = arg; - - while (true) { - const byte: u8 = @truncate(value & 0x7f); - value >>= 7; - if (value == 0) { - try writer.writeByte(byte); - break; - } else { - try writer.writeByte(byte | 0x80); - } - } -} - -/// Deprecated: use `writeUleb128` -pub const writeULEB128 = writeUleb128; - -/// Read a single signed LEB128 value from the given reader as type T, -/// or error.Overflow if the value cannot fit. -pub fn readIleb128(comptime T: type, reader: anytype) !T { - const S = if (@typeInfo(T).int.bits < 8) i8 else T; - const U = std.meta.Int(.unsigned, @typeInfo(S).int.bits); - const ShiftU = std.math.Log2Int(U); - - const max_group = (@typeInfo(U).int.bits + 6) / 7; - - var value = @as(U, 0); - var group = @as(ShiftU, 0); - - while (group < max_group) : (group += 1) { - const byte = try reader.readByte(); - - const shift = group * 7; - const ov = @shlWithOverflow(@as(U, byte & 0x7f), shift); - if (ov[1] != 0) { - // Overflow is ok so long as the sign bit is set and this is the last byte - if (byte & 0x80 != 0) return error.Overflow; - if (@as(S, @bitCast(ov[0])) >= 0) return error.Overflow; - - // and all the overflowed bits are 1 - const remaining_shift = @as(u3, @intCast(@typeInfo(U).int.bits - @as(u16, shift))); - const remaining_bits = @as(i8, @bitCast(byte | 0x80)) >> remaining_shift; - if (remaining_bits != -1) return error.Overflow; - } else { - // If we don't overflow and this is the last byte and the number being decoded - // is negative, check that the remaining bits are 1 - if ((byte & 0x80 == 0) and (@as(S, @bitCast(ov[0])) < 0)) { - const remaining_shift = @as(u3, @intCast(@typeInfo(U).int.bits - @as(u16, shift))); - const remaining_bits = @as(i8, @bitCast(byte | 0x80)) >> remaining_shift; - if (remaining_bits != -1) return error.Overflow; - } - } - - value |= ov[0]; - if (byte & 0x80 == 0) { - const needs_sign_ext = group + 1 < max_group; - if (byte & 0x40 != 0 and needs_sign_ext) { - const ones = @as(S, -1); - value |= @as(U, @bitCast(ones)) << (shift + 7); - } - break; - } - } else { - return error.Overflow; - } - - const result = @as(S, @bitCast(value)); - // Only applies if we extended to i8 - if (S != T) { - if (result > std.math.maxInt(T) or result < std.math.minInt(T)) return error.Overflow; - } - - return @as(T, @truncate(result)); -} - -/// Deprecated: use `readIleb128` -pub const readILEB128 = readIleb128; - -/// Write a single signed integer as signed LEB128 to the given writer. -pub fn writeIleb128(writer: anytype, arg: anytype) !void { - const Arg = @TypeOf(arg); - const Int = switch (Arg) { - comptime_int => std.math.IntFittingRange(-@abs(arg), @abs(arg)), - else => Arg, - }; - const Signed = if (@typeInfo(Int).int.bits < 8) i8 else Int; - const Unsigned = std.meta.Int(.unsigned, @typeInfo(Signed).int.bits); - var value: Signed = arg; - - while (true) { - const unsigned: Unsigned = @bitCast(value); - const byte: u8 = @truncate(unsigned); - value >>= 6; - if (value == -1 or value == 0) { - try writer.writeByte(byte & 0x7F); - break; - } else { - value >>= 1; - try writer.writeByte(byte | 0x80); - } - } -} - /// This is an "advanced" function. It allows one to use a fixed amount of memory to store a /// ULEB128. This defeats the entire purpose of using this data encoding; it will no longer use /// fewer bytes to store smaller numbers. The advantage of using a fixed width is that it makes @@ -176,9 +31,6 @@ pub fn writeUnsignedExtended(slice: []u8, arg: anytype) void { slice[slice.len - 1] = @as(u7, @intCast(value)); } -/// Deprecated: use `writeIleb128` -pub const writeILEB128 = writeIleb128; - test writeUnsignedFixed { { var buf: [4]u8 = undefined; @@ -261,42 +113,45 @@ test writeSignedFixed { } } -// tests fn test_read_stream_ileb128(comptime T: type, encoded: []const u8) !T { - var reader = std.io.fixedBufferStream(encoded); - return try readIleb128(T, reader.reader()); + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); + return br.takeIleb128(T); } fn test_read_stream_uleb128(comptime T: type, encoded: []const u8) !T { - var reader = std.io.fixedBufferStream(encoded); - return try readUleb128(T, reader.reader()); + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); + return br.takeUleb128(T); } fn test_read_ileb128(comptime T: type, encoded: []const u8) !T { - var reader = std.io.fixedBufferStream(encoded); - const v1 = try readIleb128(T, reader.reader()); - return v1; + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); + return br.readIleb128(T); } fn test_read_uleb128(comptime T: type, encoded: []const u8) !T { - var reader = std.io.fixedBufferStream(encoded); - const v1 = try readUleb128(T, reader.reader()); - return v1; + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); + return br.readUleb128(T); } fn test_read_ileb128_seq(comptime T: type, comptime N: usize, encoded: []const u8) !void { - var reader = std.io.fixedBufferStream(encoded); + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); var i: usize = 0; while (i < N) : (i += 1) { - _ = try readIleb128(T, reader.reader()); + _ = try br.readIleb128(T); } } fn test_read_uleb128_seq(comptime T: type, comptime N: usize, encoded: []const u8) !void { - var reader = std.io.fixedBufferStream(encoded); + var br: std.io.BufferedReader = undefined; + br.initFixed(encoded); var i: usize = 0; while (i < N) : (i += 1) { - _ = try readUleb128(T, reader.reader()); + _ = try br.readUleb128(T); } } @@ -392,8 +247,8 @@ fn test_write_leb128(value: anytype) !void { const signedness = @typeInfo(T).int.signedness; const t_signed = signedness == .signed; - const writeStream = if (t_signed) writeIleb128 else writeUleb128; - const readStream = if (t_signed) readIleb128 else readUleb128; + const writeStream = if (t_signed) std.io.BufferedWriter.writeIleb128 else std.io.BufferedWriter.writeUleb128; + const readStream = if (t_signed) std.io.BufferedReader.readIleb128 else std.io.BufferedReader.readUleb128; // decode to a larger bit size too, to ensure sign extension // is working as expected @@ -412,23 +267,24 @@ fn test_write_leb128(value: anytype) !void { const max_groups = if (@typeInfo(T).int.bits == 0) 1 else (@typeInfo(T).int.bits + 6) / 7; var buf: [max_groups]u8 = undefined; - var fbs = std.io.fixedBufferStream(&buf); + var bw: std.io.BufferedWriter = undefined; + bw.initFixed(&buf); // stream write - try writeStream(fbs.writer(), value); - const w1_pos = fbs.pos; - try testing.expect(w1_pos == bytes_needed); + try testing.expect((try writeStream(&bw, value)) == bytes_needed); + try testing.expect(bw.buffer.items.len == bytes_needed); // stream read - fbs.pos = 0; - const sr = try readStream(T, fbs.reader()); - try testing.expect(fbs.pos == w1_pos); + var br: std.io.BufferedReader = undefined; + br.initFixed(&buf); + const sr = try readStream(&br, T); + try testing.expect(br.seek == bytes_needed); try testing.expect(sr == value); // bigger type stream read - fbs.pos = 0; - const bsr = try readStream(B, fbs.reader()); - try testing.expect(fbs.pos == w1_pos); + bw.buffer.items.len = 0; + const bsr = try readStream(&bw, B); + try testing.expect(bw.buffer.items.len == bytes_needed); try testing.expect(bsr == value); } diff --git a/lib/std/zig/ErrorBundle.zig b/lib/std/zig/ErrorBundle.zig index acc3d2939d..9d7a934c0d 100644 --- a/lib/std/zig/ErrorBundle.zig +++ b/lib/std/zig/ErrorBundle.zig @@ -189,24 +189,22 @@ fn renderErrorMessageToWriter( indent: usize, ) anyerror!void { const ttyconf = options.ttyconf; - var counting_writer: std.io.CountingWriter = .{ .child_writer = bw.writer() }; - var counting_bw = counting_writer.writer().unbuffered(); const err_msg = eb.getErrorMessage(err_msg_index); + // This is the length of the part before the error message: + // e.g. "file.zig:4:5: error: " + var prefix_len: usize = 0; if (err_msg.src_loc != .none) { const src = eb.extraData(SourceLocation, @intFromEnum(err_msg.src_loc)); - try counting_bw.splatByteAll(' ', indent); + prefix_len += try bw.splatByteAllCount(' ', indent); try ttyconf.setColor(bw, .bold); - try counting_bw.print("{s}:{d}:{d}: ", .{ + prefix_len += try bw.printCount("{s}:{d}:{d}: ", .{ eb.nullTerminatedString(src.data.src_path), src.data.line + 1, src.data.column + 1, }); try ttyconf.setColor(bw, color); - try counting_bw.writeAll(kind); - try counting_bw.writeAll(": "); - // This is the length of the part before the error message: - // e.g. "file.zig:4:5: error: " - const prefix_len: usize = @intCast(counting_writer.bytes_written); + prefix_len += try bw.writeAllCount(kind); + prefix_len += try bw.writeAllCount(": "); try ttyconf.setColor(bw, .reset); try ttyconf.setColor(bw, .bold); if (err_msg.count == 1) { diff --git a/lib/std/zip/test.zig b/lib/std/zip/test.zig index aba49e7af2..ce1c6198e7 100644 --- a/lib/std/zip/test.zig +++ b/lib/std/zip/test.zig @@ -103,13 +103,13 @@ pub const Zip64Options = struct { }; pub fn writeZip( - writer: anytype, + writer: *std.io.BufferedWriter, files: []const File, store: []FileStore, options: WriteZipOptions, ) !void { if (store.len < files.len) return error.FileStoreTooSmall; - var zipper = initZipper(writer); + var zipper: Zipper = .init(writer); for (files, 0..) |file, i| { store[i] = try zipper.writeFile(.{ .name = file.name, @@ -126,173 +126,172 @@ pub fn writeZip( try zipper.writeEndRecord(if (options.end) |e| e else .{}); } -pub fn initZipper(writer: anytype) Zipper(@TypeOf(writer)) { - return .{ .counting_writer = std.io.countingWriter(writer) }; -} - /// Provides methods to format and write the contents of a zip archive /// to the underlying Writer. -pub fn Zipper(comptime Writer: type) type { - return struct { - counting_writer: std.io.CountingWriter(Writer), - central_count: u64 = 0, - first_central_offset: ?u64 = null, - last_central_limit: ?u64 = null, +pub const Zipper = struct { + writer: *std.io.BufferedWriter, + bytes_written: u64, + central_count: u64 = 0, + first_central_offset: ?u64 = null, + last_central_limit: ?u64 = null, - const Self = @This(); + const Self = @This(); - pub fn writeFile( - self: *Self, - opt: struct { - name: []const u8, - content: []const u8, - compression: zip.CompressionMethod, - write_options: WriteZipOptions, - }, - ) !FileStore { - const writer = self.counting_writer.writer(); + pub fn init(writer: *std.io.BufferedWriter) Zipper { + return .{ .writer = writer, .bytes_written = 0 }; + } - const file_offset: u64 = @intCast(self.counting_writer.bytes_written); - const crc32 = std.hash.Crc32.hash(opt.content); + pub fn writeFile( + self: *Self, + opt: struct { + name: []const u8, + content: []const u8, + compression: zip.CompressionMethod, + write_options: WriteZipOptions, + }, + ) !FileStore { + const writer = self.writer; - const header_options = opt.write_options.local_header; - { - var compressed_size: u32 = 0; - var uncompressed_size: u32 = 0; - var extra_len: u16 = 0; - if (header_options) |hdr_options| { - compressed_size = if (hdr_options.compressed_size) |size| size else 0; - uncompressed_size = if (hdr_options.uncompressed_size) |size| size else @intCast(opt.content.len); - extra_len = if (hdr_options.extra_len) |len| len else 0; - } - const hdr: zip.LocalFileHeader = .{ - .signature = zip.local_file_header_sig, - .version_needed_to_extract = 10, - .flags = .{ .encrypted = false, ._ = 0 }, - .compression_method = opt.compression, - .last_modification_time = 0, - .last_modification_date = 0, - .crc32 = crc32, - .compressed_size = compressed_size, - .uncompressed_size = uncompressed_size, - .filename_len = @intCast(opt.name.len), - .extra_len = extra_len, - }; - try writer.writeStructEndian(hdr, .little); + const file_offset: u64 = @intCast(self.bytes_written); + const crc32 = std.hash.Crc32.hash(opt.content); + + const header_options = opt.write_options.local_header; + { + var compressed_size: u32 = 0; + var uncompressed_size: u32 = 0; + var extra_len: u16 = 0; + if (header_options) |hdr_options| { + compressed_size = if (hdr_options.compressed_size) |size| size else 0; + uncompressed_size = if (hdr_options.uncompressed_size) |size| size else @intCast(opt.content.len); + extra_len = if (hdr_options.extra_len) |len| len else 0; } - try writer.writeAll(opt.name); - - if (header_options) |hdr| { - if (hdr.zip64) |options| { - try writer.writeInt(u16, 0x0001, .little); - const data_size = if (options.data_size) |size| size else 8; - try writer.writeInt(u16, data_size, .little); - try writer.writeInt(u64, 0, .little); - try writer.writeInt(u64, @intCast(opt.content.len), .little); - } - } - - var compressed_size: u32 = undefined; - switch (opt.compression) { - .store => { - try writer.writeAll(opt.content); - compressed_size = @intCast(opt.content.len); - }, - .deflate => { - const offset = self.counting_writer.bytes_written; - var fbs = std.io.fixedBufferStream(opt.content); - try std.compress.flate.deflate.compress(.raw, fbs.reader(), writer, .{}); - std.debug.assert(fbs.pos == opt.content.len); - compressed_size = @intCast(self.counting_writer.bytes_written - offset); - }, - else => unreachable, - } - return .{ - .compression = opt.compression, - .file_offset = file_offset, - .crc32 = crc32, - .compressed_size = compressed_size, - .uncompressed_size = opt.content.len, - }; - } - - pub fn writeCentralRecord( - self: *Self, - store: FileStore, - opt: struct { - name: []const u8, - version_needed_to_extract: u16 = 10, - }, - ) !void { - if (self.first_central_offset == null) { - self.first_central_offset = self.counting_writer.bytes_written; - } - self.central_count += 1; - - const hdr: zip.CentralDirectoryFileHeader = .{ - .signature = zip.central_file_header_sig, - .version_made_by = 0, - .version_needed_to_extract = opt.version_needed_to_extract, + const hdr: zip.LocalFileHeader = .{ + .signature = zip.local_file_header_sig, + .version_needed_to_extract = 10, .flags = .{ .encrypted = false, ._ = 0 }, - .compression_method = store.compression, + .compression_method = opt.compression, .last_modification_time = 0, .last_modification_date = 0, - .crc32 = store.crc32, - .compressed_size = store.compressed_size, - .uncompressed_size = @intCast(store.uncompressed_size), + .crc32 = crc32, + .compressed_size = compressed_size, + .uncompressed_size = uncompressed_size, .filename_len = @intCast(opt.name.len), - .extra_len = 0, - .comment_len = 0, - .disk_number = 0, - .internal_file_attributes = 0, - .external_file_attributes = 0, - .local_file_header_offset = @intCast(store.file_offset), + .extra_len = extra_len, }; - try self.counting_writer.writer().writeStructEndian(hdr, .little); - try self.counting_writer.writer().writeAll(opt.name); - self.last_central_limit = self.counting_writer.bytes_written; + self.bytes_written += try writer.writeStructEndian(hdr, .little); } + self.bytes_written += try writer.writeAll(opt.name); - pub fn writeEndRecord(self: *Self, opt: EndRecordOptions) !void { - const cd_offset = self.first_central_offset orelse 0; - const cd_end = self.last_central_limit orelse 0; - - if (opt.zip64) |zip64| { - const end64_off = cd_end; - const fixed: zip.EndRecord64 = .{ - .signature = zip.end_record64_sig, - .end_record_size = @sizeOf(zip.EndRecord64) - 12, - .version_made_by = 0, - .version_needed_to_extract = 45, - .disk_number = 0, - .central_directory_disk_number = 0, - .record_count_disk = @intCast(self.central_count), - .record_count_total = @intCast(self.central_count), - .central_directory_size = @intCast(cd_end - cd_offset), - .central_directory_offset = @intCast(cd_offset), - }; - try self.counting_writer.writer().writeStructEndian(fixed, .little); - const locator: zip.EndLocator64 = .{ - .signature = if (zip64.locator_sig) |s| s else zip.end_locator64_sig, - .zip64_disk_count = if (zip64.locator_zip64_disk_count) |c| c else 0, - .record_file_offset = if (zip64.locator_record_file_offset) |o| o else @intCast(end64_off), - .total_disk_count = if (zip64.locator_total_disk_count) |c| c else 1, - }; - try self.counting_writer.writer().writeStructEndian(locator, .little); + if (header_options) |hdr| { + if (hdr.zip64) |options| { + self.bytes_written += try writer.writeInt(u16, 0x0001, .little); + const data_size = if (options.data_size) |size| size else 8; + self.bytes_written += try writer.writeInt(u16, data_size, .little); + self.bytes_written += try writer.writeInt(u64, 0, .little); + self.bytes_written += try writer.writeInt(u64, @intCast(opt.content.len), .little); } - const hdr: zip.EndRecord = .{ - .signature = if (opt.sig) |s| s else zip.end_record_sig, - .disk_number = if (opt.disk_number) |n| n else 0, - .central_directory_disk_number = if (opt.central_directory_disk_number) |n| n else 0, - .record_count_disk = if (opt.record_count_disk) |c| c else @intCast(self.central_count), - .record_count_total = if (opt.record_count_total) |c| c else @intCast(self.central_count), - .central_directory_size = if (opt.central_directory_size) |s| s else @intCast(cd_end - cd_offset), - .central_directory_offset = if (opt.central_directory_offset) |o| o else @intCast(cd_offset), - .comment_len = if (opt.comment_len) |l| l else (if (opt.comment) |c| @as(u16, @intCast(c.len)) else 0), - }; - try self.counting_writer.writer().writeStructEndian(hdr, .little); - if (opt.comment) |c| - try self.counting_writer.writer().writeAll(c); } - }; -} + + var compressed_size: u32 = undefined; + switch (opt.compression) { + .store => { + self.bytes_written += try writer.writeAll(opt.content); + compressed_size = @intCast(opt.content.len); + }, + .deflate => { + const offset = self.bytes_written; + var fbs = std.io.fixedBufferStream(opt.content); + self.bytes_written += try std.compress.flate.deflate.compress(.raw, fbs.reader(), writer, .{}); + std.debug.assert(fbs.pos == opt.content.len); + compressed_size = @intCast(self.bytes_written - offset); + }, + else => unreachable, + } + return .{ + .compression = opt.compression, + .file_offset = file_offset, + .crc32 = crc32, + .compressed_size = compressed_size, + .uncompressed_size = opt.content.len, + }; + } + + pub fn writeCentralRecord( + self: *Self, + store: FileStore, + opt: struct { + name: []const u8, + version_needed_to_extract: u16 = 10, + }, + ) !void { + if (self.first_central_offset == null) { + self.first_central_offset = self.bytes_written; + } + self.central_count += 1; + + const hdr: zip.CentralDirectoryFileHeader = .{ + .signature = zip.central_file_header_sig, + .version_made_by = 0, + .version_needed_to_extract = opt.version_needed_to_extract, + .flags = .{ .encrypted = false, ._ = 0 }, + .compression_method = store.compression, + .last_modification_time = 0, + .last_modification_date = 0, + .crc32 = store.crc32, + .compressed_size = store.compressed_size, + .uncompressed_size = @intCast(store.uncompressed_size), + .filename_len = @intCast(opt.name.len), + .extra_len = 0, + .comment_len = 0, + .disk_number = 0, + .internal_file_attributes = 0, + .external_file_attributes = 0, + .local_file_header_offset = @intCast(store.file_offset), + }; + self.bytes_written += try self.writer.writeStructEndian(hdr, .little); + self.bytes_written += try self.writer.writeAll(opt.name); + self.last_central_limit = self.bytes_written; + } + + pub fn writeEndRecord(self: *Self, opt: EndRecordOptions) !void { + const cd_offset = self.first_central_offset orelse 0; + const cd_end = self.last_central_limit orelse 0; + + if (opt.zip64) |zip64| { + const end64_off = cd_end; + const fixed: zip.EndRecord64 = .{ + .signature = zip.end_record64_sig, + .end_record_size = @sizeOf(zip.EndRecord64) - 12, + .version_made_by = 0, + .version_needed_to_extract = 45, + .disk_number = 0, + .central_directory_disk_number = 0, + .record_count_disk = @intCast(self.central_count), + .record_count_total = @intCast(self.central_count), + .central_directory_size = @intCast(cd_end - cd_offset), + .central_directory_offset = @intCast(cd_offset), + }; + self.bytes_written += try self.writer.writeStructEndian(fixed, .little); + const locator: zip.EndLocator64 = .{ + .signature = if (zip64.locator_sig) |s| s else zip.end_locator64_sig, + .zip64_disk_count = if (zip64.locator_zip64_disk_count) |c| c else 0, + .record_file_offset = if (zip64.locator_record_file_offset) |o| o else @intCast(end64_off), + .total_disk_count = if (zip64.locator_total_disk_count) |c| c else 1, + }; + self.bytes_written += try self.writer.writeStructEndian(locator, .little); + } + const hdr: zip.EndRecord = .{ + .signature = if (opt.sig) |s| s else zip.end_record_sig, + .disk_number = if (opt.disk_number) |n| n else 0, + .central_directory_disk_number = if (opt.central_directory_disk_number) |n| n else 0, + .record_count_disk = if (opt.record_count_disk) |c| c else @intCast(self.central_count), + .record_count_total = if (opt.record_count_total) |c| c else @intCast(self.central_count), + .central_directory_size = if (opt.central_directory_size) |s| s else @intCast(cd_end - cd_offset), + .central_directory_offset = if (opt.central_directory_offset) |o| o else @intCast(cd_offset), + .comment_len = if (opt.comment_len) |l| l else (if (opt.comment) |c| @as(u16, @intCast(c.len)) else 0), + }; + self.bytes_written += try self.writer.writeStructEndian(hdr, .little); + if (opt.comment) |c| + self.bytes_written += try self.writer.writeAll(c); + } +}; diff --git a/src/Type.zig b/src/Type.zig index 81dad15acc..db09ed94ac 100644 --- a/src/Type.zig +++ b/src/Type.zig @@ -121,11 +121,10 @@ pub fn eql(a: Type, b: Type, zcu: *const Zcu) bool { return a.toIntern() == b.toIntern(); } -pub fn format(ty: Type, comptime unused_fmt_string: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { +pub fn format(ty: Type, bw: *std.io.BufferedWriter, comptime f: []const u8) anyerror!usize { _ = ty; - _ = unused_fmt_string; - _ = options; - _ = writer; + _ = f; + _ = bw; @compileError("do not format types directly; use either ty.fmtDebug() or ty.fmt()"); } @@ -143,15 +142,9 @@ const FormatContext = struct { pt: Zcu.PerThread, }; -fn format2( - ctx: FormatContext, - comptime unused_format_string: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) !void { - comptime assert(unused_format_string.len == 0); - _ = options; - return print(ctx.ty, writer, ctx.pt); +fn format2(ctx: FormatContext, bw: *std.io.BufferedWriter, comptime f: []const u8) anyerror!usize { + comptime assert(f.len == 0); + return print(ctx.ty, bw, ctx.pt); } pub fn fmtDebug(ty: Type) std.fmt.Formatter(dump) { @@ -173,7 +166,7 @@ pub fn dump( /// Prints a name suitable for `@typeName`. /// TODO: take an `opt_sema` to pass to `fmtValue` when printing sentinels. -pub fn print(ty: Type, writer: *std.io.BufferedWriter, pt: Zcu.PerThread) anyerror!void { +pub fn print(ty: Type, bw: *std.io.BufferedWriter, pt: Zcu.PerThread) anyerror!usize { const zcu = pt.zcu; const ip = &zcu.intern_pool; switch (ip.indexToKey(ty.toIntern())) { @@ -183,22 +176,23 @@ pub fn print(ty: Type, writer: *std.io.BufferedWriter, pt: Zcu.PerThread) anyerr .signed => 'i', .unsigned => 'u', }; - return writer.print("{c}{d}", .{ sign_char, int_type.bits }); + return bw.print("{c}{d}", .{ sign_char, int_type.bits }); }, .ptr_type => { + var n: usize = 0; const info = ty.ptrInfo(zcu); if (info.sentinel != .none) switch (info.flags.size) { .one, .c => unreachable, - .many => try writer.print("[*:{}]", .{Value.fromInterned(info.sentinel).fmtValue(pt)}), - .slice => try writer.print("[:{}]", .{Value.fromInterned(info.sentinel).fmtValue(pt)}), + .many => n += try bw.print("[*:{}]", .{Value.fromInterned(info.sentinel).fmtValue(pt)}), + .slice => n += try bw.print("[:{}]", .{Value.fromInterned(info.sentinel).fmtValue(pt)}), } else switch (info.flags.size) { - .one => try writer.writeAll("*"), - .many => try writer.writeAll("[*]"), - .c => try writer.writeAll("[*c]"), - .slice => try writer.writeAll("[]"), + .one => n += try bw.writeAll("*"), + .many => n += try bw.writeAll("[*]"), + .c => n += try bw.writeAll("[*c]"), + .slice => n += try bw.writeAll("[]"), } - if (info.flags.is_allowzero and info.flags.size != .c) try writer.writeAll("allowzero "); + if (info.flags.is_allowzero and info.flags.size != .c) n += try bw.writeAll("allowzero "); if (info.flags.alignment != .none or info.packed_offset.host_size != 0 or info.flags.vector_index != .none) @@ -207,76 +201,83 @@ pub fn print(ty: Type, writer: *std.io.BufferedWriter, pt: Zcu.PerThread) anyerr info.flags.alignment else Type.fromInterned(info.child).abiAlignment(pt.zcu); - try writer.print("align({d}", .{alignment.toByteUnits() orelse 0}); + n += try bw.print("align({d}", .{alignment.toByteUnits() orelse 0}); if (info.packed_offset.bit_offset != 0 or info.packed_offset.host_size != 0) { - try writer.print(":{d}:{d}", .{ + n += try bw.print(":{d}:{d}", .{ info.packed_offset.bit_offset, info.packed_offset.host_size, }); } if (info.flags.vector_index == .runtime) { - try writer.writeAll(":?"); + n += try bw.writeAll(":?"); } else if (info.flags.vector_index != .none) { - try writer.print(":{d}", .{@intFromEnum(info.flags.vector_index)}); + n += try bw.print(":{d}", .{@intFromEnum(info.flags.vector_index)}); } - try writer.writeAll(") "); + n += try bw.writeAll(") "); } if (info.flags.address_space != .generic) { - try writer.print("addrspace(.{s}) ", .{@tagName(info.flags.address_space)}); + n += try bw.print("addrspace(.{s}) ", .{@tagName(info.flags.address_space)}); } - if (info.flags.is_const) try writer.writeAll("const "); - if (info.flags.is_volatile) try writer.writeAll("volatile "); + if (info.flags.is_const) n += try bw.writeAll("const "); + if (info.flags.is_volatile) n += try bw.writeAll("volatile "); - try print(Type.fromInterned(info.child), writer, pt); - return; + n += try print(Type.fromInterned(info.child), bw, pt); + return n; }, .array_type => |array_type| { + var n: usize = 0; if (array_type.sentinel == .none) { - try writer.print("[{d}]", .{array_type.len}); - try print(Type.fromInterned(array_type.child), writer, pt); + n += try bw.print("[{d}]", .{array_type.len}); + n += try print(Type.fromInterned(array_type.child), bw, pt); } else { - try writer.print("[{d}:{}]", .{ + n += try bw.print("[{d}:{}]", .{ array_type.len, Value.fromInterned(array_type.sentinel).fmtValue(pt), }); - try print(Type.fromInterned(array_type.child), writer, pt); + n += try print(Type.fromInterned(array_type.child), bw, pt); } - return; + return n; }, .vector_type => |vector_type| { - try writer.print("@Vector({d}, ", .{vector_type.len}); - try print(Type.fromInterned(vector_type.child), writer, pt); - try writer.writeAll(")"); - return; + var n: usize = 0; + n += try bw.print("@Vector({d}, ", .{vector_type.len}); + n += try print(Type.fromInterned(vector_type.child), bw, pt); + n += try bw.writeAll(")"); + return n; }, .opt_type => |child| { - try writer.writeByte('?'); - return print(Type.fromInterned(child), writer, pt); + var n: usize = 0; + n += try bw.writeByte('?'); + n += try print(Type.fromInterned(child), bw, pt); + return n; }, .error_union_type => |error_union_type| { - try print(Type.fromInterned(error_union_type.error_set_type), writer, pt); - try writer.writeByte('!'); + var n: usize = 0; + n += try print(Type.fromInterned(error_union_type.error_set_type), bw, pt); + n += try bw.writeByte('!'); if (error_union_type.payload_type == .generic_poison_type) { - try writer.writeAll("anytype"); + n += try bw.writeAll("anytype"); } else { - try print(Type.fromInterned(error_union_type.payload_type), writer, pt); + n += try print(Type.fromInterned(error_union_type.payload_type), bw, pt); } - return; + return n; }, .inferred_error_set_type => |func_index| { const func_nav = ip.getNav(zcu.funcInfo(func_index).owner_nav); - try writer.print("@typeInfo(@typeInfo(@TypeOf({})).@\"fn\".return_type.?).error_union.error_set", .{ + return bw.print("@typeInfo(@typeInfo(@TypeOf({})).@\"fn\".return_type.?).error_union.error_set", .{ func_nav.fqn.fmt(ip), }); }, .error_set_type => |error_set_type| { + var n: usize = 0; const names = error_set_type.names; - try writer.writeAll("error{"); + n += try bw.writeAll("error{"); for (names.get(ip), 0..) |name, i| { - if (i != 0) try writer.writeByte(','); - try writer.print("{}", .{name.fmt(ip)}); + if (i != 0) n += try bw.writeByte(','); + n += try bw.print("{}", .{name.fmt(ip)}); } - try writer.writeAll("}"); + n += try bw.writeAll("}"); + return n; }, .simple_type => |s| switch (s) { .f16, @@ -305,97 +306,103 @@ pub fn print(ty: Type, writer: *std.io.BufferedWriter, pt: Zcu.PerThread) anyerr .comptime_float, .noreturn, .adhoc_inferred_error_set, - => return writer.writeAll(@tagName(s)), + => return bw.writeAll(@tagName(s)), .null, .undefined, - => try writer.print("@TypeOf({s})", .{@tagName(s)}), + => return bw.print("@TypeOf({s})", .{@tagName(s)}), - .enum_literal => try writer.writeAll("@Type(.enum_literal)"), + .enum_literal => return bw.writeAll("@Type(.enum_literal)"), .generic_poison => unreachable, }, .struct_type => { const name = ip.loadStructType(ty.toIntern()).name; - try writer.print("{}", .{name.fmt(ip)}); + return bw.print("{}", .{name.fmt(ip)}); }, .tuple_type => |tuple| { if (tuple.types.len == 0) { - return writer.writeAll("@TypeOf(.{})"); + return bw.writeAll("@TypeOf(.{})"); } - try writer.writeAll("struct {"); + var n: usize = 0; + n += try bw.writeAll("struct {"); for (tuple.types.get(ip), tuple.values.get(ip), 0..) |field_ty, val, i| { - try writer.writeAll(if (i == 0) " " else ", "); - if (val != .none) try writer.writeAll("comptime "); - try print(Type.fromInterned(field_ty), writer, pt); - if (val != .none) try writer.print(" = {}", .{Value.fromInterned(val).fmtValue(pt)}); + n += try bw.writeAll(if (i == 0) " " else ", "); + if (val != .none) n += try bw.writeAll("comptime "); + n += try print(Type.fromInterned(field_ty), bw, pt); + if (val != .none) n += try bw.print(" = {}", .{Value.fromInterned(val).fmtValue(pt)}); } - try writer.writeAll(" }"); + n += try bw.writeAll(" }"); + return n; }, .union_type => { const name = ip.loadUnionType(ty.toIntern()).name; - try writer.print("{}", .{name.fmt(ip)}); + return bw.print("{}", .{name.fmt(ip)}); }, .opaque_type => { const name = ip.loadOpaqueType(ty.toIntern()).name; - try writer.print("{}", .{name.fmt(ip)}); + return bw.print("{}", .{name.fmt(ip)}); }, .enum_type => { const name = ip.loadEnumType(ty.toIntern()).name; - try writer.print("{}", .{name.fmt(ip)}); + return bw.print("{}", .{name.fmt(ip)}); }, .func_type => |fn_info| { + var n: usize = 0; if (fn_info.is_noinline) { - try writer.writeAll("noinline "); + n += try bw.writeAll("noinline "); } - try writer.writeAll("fn ("); + n += try bw.writeAll("fn ("); const param_types = fn_info.param_types.get(&zcu.intern_pool); for (param_types, 0..) |param_ty, i| { - if (i != 0) try writer.writeAll(", "); + if (i != 0) n += try bw.writeAll(", "); if (std.math.cast(u5, i)) |index| { if (fn_info.paramIsComptime(index)) { - try writer.writeAll("comptime "); + n += try bw.writeAll("comptime "); } if (fn_info.paramIsNoalias(index)) { - try writer.writeAll("noalias "); + n += try bw.writeAll("noalias "); } } if (param_ty == .generic_poison_type) { - try writer.writeAll("anytype"); + n += try bw.writeAll("anytype"); } else { - try print(Type.fromInterned(param_ty), writer, pt); + n += try print(Type.fromInterned(param_ty), bw, pt); } } if (fn_info.is_var_args) { if (param_types.len != 0) { - try writer.writeAll(", "); + n += try bw.writeAll(", "); } - try writer.writeAll("..."); + n += try bw.writeAll("..."); } - try writer.writeAll(") "); + n += try bw.writeAll(") "); if (fn_info.cc != .auto) print_cc: { if (zcu.getTarget().cCallingConvention()) |ccc| { if (fn_info.cc.eql(ccc)) { - try writer.writeAll("callconv(.c) "); + n += try bw.writeAll("callconv(.c) "); break :print_cc; } } switch (fn_info.cc) { - .auto, .@"async", .naked, .@"inline" => try writer.print("callconv(.{}) ", .{std.zig.fmtId(@tagName(fn_info.cc))}), - else => try writer.print("callconv({any}) ", .{fn_info.cc}), + .auto, .@"async", .naked, .@"inline" => n += try bw.print("callconv(.{}) ", .{std.zig.fmtId(@tagName(fn_info.cc))}), + else => n += try bw.print("callconv({any}) ", .{fn_info.cc}), } } if (fn_info.return_type == .generic_poison_type) { - try writer.writeAll("anytype"); + n += try bw.writeAll("anytype"); } else { - try print(Type.fromInterned(fn_info.return_type), writer, pt); + n += try print(Type.fromInterned(fn_info.return_type), bw, pt); } + return n; }, .anyframe_type => |child| { - if (child == .none) return writer.writeAll("anyframe"); - try writer.writeAll("anyframe->"); - return print(Type.fromInterned(child), writer, pt); + if (child == .none) return bw.writeAll("anyframe"); + var n: usize = 0; + n += try bw.writeAll("anyframe->"); + n += print(Type.fromInterned(child), bw, pt); + return n; }, // values, not types diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 7ee33577d8..42105d622a 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -1270,7 +1270,7 @@ pub const DeclGen = struct { } const ai = ty.arrayInfo(zcu); if (ai.elem_type.eql(.u8, zcu)) { - var literal = stringLiteral(writer, ty.arrayLenIncludingSentinel(zcu)); + var literal: StringLiteral = .init(writer, ty.arrayLenIncludingSentinel(zcu)); try literal.start(); var index: usize = 0; while (index < ai.len) : (index += 1) { @@ -1829,7 +1829,7 @@ pub const DeclGen = struct { const ai = ty.arrayInfo(zcu); if (ai.elem_type.eql(.u8, zcu)) { const c_len = ty.arrayLenIncludingSentinel(zcu); - var literal = stringLiteral(writer, c_len); + var literal: StringLiteral = .init(writer, c_len); try literal.start(); var index: u64 = 0; while (index < c_len) : (index += 1) @@ -8111,7 +8111,12 @@ fn compareOperatorC(operator: std.math.CompareOperator) []const u8 { }; } -fn StringLiteral(comptime WriterType: type) type { +const StringLiteral = struct { + len: usize, + cur_len: usize, + bytes_written: usize, + writer: *std.io.BufferedWriter, + // MSVC throws C2078 if an array of size 65536 or greater is initialized with a string literal, // regardless of the length of the string literal initializing it. Array initializer syntax is // used instead. @@ -8123,81 +8128,68 @@ fn StringLiteral(comptime WriterType: type) type { const max_char_len = 4; const max_literal_len = @min(16380 - max_char_len, 4095); - return struct { - len: u64, - cur_len: u64 = 0, - counting_writer: std.io.CountingWriter(WriterType), + fn init(writer: *std.io.BufferedWriter, len: usize) StringLiteral { + return .{ + .cur_len = 0, + .len = len, + .writer = writer, + .bytes_written = 0, + }; + } - pub const Error = WriterType.Error; - - const Self = @This(); - - pub fn start(self: *Self) Error!void { - const writer = self.counting_writer.writer(); - if (self.len <= max_string_initializer_len) { - try writer.writeByte('\"'); - } else { - try writer.writeByte('{'); - } + pub fn start(self: *StringLiteral) anyerror!void { + const writer = self.writer; + if (self.len <= max_string_initializer_len) { + self.bytes_written += try writer.writeByteCount('\"'); + } else { + self.bytes_written += try writer.writeByteCount('{'); } + } - pub fn end(self: *Self) Error!void { - const writer = self.counting_writer.writer(); - if (self.len <= max_string_initializer_len) { - try writer.writeByte('\"'); - } else { - try writer.writeByte('}'); - } + pub fn end(self: *StringLiteral) anyerror!void { + const writer = self.writer; + if (self.len <= max_string_initializer_len) { + self.bytes_written += try writer.writeByteCount('\"'); + } else { + self.bytes_written += try writer.writeByteCount('}'); } + } - fn writeStringLiteralChar(writer: anytype, c: u8) !void { - switch (c) { - 7 => try writer.writeAll("\\a"), - 8 => try writer.writeAll("\\b"), - '\t' => try writer.writeAll("\\t"), - '\n' => try writer.writeAll("\\n"), - 11 => try writer.writeAll("\\v"), - 12 => try writer.writeAll("\\f"), - '\r' => try writer.writeAll("\\r"), - '"', '\'', '?', '\\' => try writer.print("\\{c}", .{c}), - else => switch (c) { - ' '...'~' => try writer.writeByte(c), - else => try writer.print("\\{o:0>3}", .{c}), - }, - } + fn writeStringLiteralChar(writer: *std.io.BufferedWriter, c: u8) anyerror!usize { + switch (c) { + 7 => return writer.writeAllCount("\\a"), + 8 => return writer.writeAllCount("\\b"), + '\t' => return writer.writeAllCount("\\t"), + '\n' => return writer.writeAllCount("\\n"), + 11 => return writer.writeAllCount("\\v"), + 12 => return writer.writeAllCount("\\f"), + '\r' => return writer.writeAllCount("\\r"), + '"', '\'', '?', '\\' => return writer.printCount("\\{c}", .{c}), + else => switch (c) { + ' '...'~' => return writer.writeByteCount(c), + else => return writer.printCount("\\{o:0>3}", .{c}), + }, } + } - pub fn writeChar(self: *Self, c: u8) Error!void { - const writer = self.counting_writer.writer(); - if (self.len <= max_string_initializer_len) { - if (self.cur_len == 0 and self.counting_writer.bytes_written > 1) - try writer.writeAll("\"\""); + pub fn writeChar(self: *StringLiteral, c: u8) anyerror!void { + const writer = self.writer; + if (self.len <= max_string_initializer_len) { + if (self.cur_len == 0 and self.bytes_written > 1) + self.bytes_written += try writer.writeAllCount("\"\""); - const len = self.counting_writer.bytes_written; - try writeStringLiteralChar(writer, c); + const char_length = try writeStringLiteralChar(writer, c); + self.bytes_written += char_length; + assert(char_length <= max_char_len); + self.cur_len += char_length; - const char_length = self.counting_writer.bytes_written - len; - assert(char_length <= max_char_len); - self.cur_len += char_length; - - if (self.cur_len >= max_literal_len) self.cur_len = 0; - } else { - if (self.counting_writer.bytes_written > 1) try writer.writeByte(','); - try writer.print("'\\x{x}'", .{c}); - } + if (self.cur_len >= max_literal_len) self.cur_len = 0; + } else { + if (self.bytes_written > 1) self.bytes_written += try writer.writeByteCount(','); + self.bytes_written += try writer.printCount("'\\x{x}'", .{c}); } - }; -} - -fn stringLiteral( - child_stream: anytype, - len: u64, -) StringLiteral(@TypeOf(child_stream)) { - return .{ - .len = len, - .counting_writer = std.io.countingWriter(child_stream), - }; -} + } +}; const FormatStringContext = struct { str: []const u8, sentinel: ?u8 }; fn formatStringLiteral( @@ -8208,7 +8200,7 @@ fn formatStringLiteral( ) @TypeOf(writer).Error!void { if (fmt.len != 1 or fmt[0] != 's') @compileError("Invalid fmt: " ++ fmt); - var literal = stringLiteral(writer, data.str.len + @intFromBool(data.sentinel != null)); + var literal: StringLiteral = .init(writer, data.str.len + @intFromBool(data.sentinel != null)); try literal.start(); for (data.str) |c| try literal.writeChar(c); if (data.sentinel) |sentinel| if (sentinel != 0) try literal.writeChar(sentinel); diff --git a/src/link/Dwarf.zig b/src/link/Dwarf.zig index 605d1d23a4..e788360358 100644 --- a/src/link/Dwarf.zig +++ b/src/link/Dwarf.zig @@ -1768,34 +1768,36 @@ pub const WipNav = struct { } const ExprLocCounter = struct { - const Stream = std.io.CountingWriter(std.io.NullWriter); - stream: Stream, + stream: *std.io.BufferedWriter, section_offset_bytes: u32, address_size: AddressSize, - fn init(dwarf: *Dwarf) ExprLocCounter { + counter: usize, + fn init(dwarf: *Dwarf, stream: *std.io.BufferedWriter) ExprLocCounter { return .{ - .stream = std.io.countingWriter(std.io.null_writer), + .stream = stream, .section_offset_bytes = dwarf.sectionOffsetBytes(), .address_size = dwarf.address_size, }; } - fn writer(counter: *ExprLocCounter) Stream.Writer { - return counter.stream.writer(); + fn writer(counter: *ExprLocCounter) *std.io.BufferedWriter { + return counter.stream; } fn endian(_: ExprLocCounter) std.builtin.Endian { return @import("builtin").cpu.arch.endian(); } fn addrSym(counter: *ExprLocCounter, _: u32) error{}!void { - counter.stream.bytes_written += @intFromEnum(counter.address_size); + counter.count += @intFromEnum(counter.address_size); } fn infoEntry(counter: *ExprLocCounter, _: Unit.Index, _: Entry.Index) error{}!void { - counter.stream.bytes_written += counter.section_offset_bytes; + counter.count += counter.section_offset_bytes; } }; fn infoExprLoc(wip_nav: *WipNav, loc: Loc) UpdateError!void { - var counter: ExprLocCounter = .init(wip_nav.dwarf); - try loc.write(&counter); + var buffer: [std.atomic.cache_line]u8 = undefined; + var counter_bw = std.io.Writer.null.buffered(&buffer); + var counter: ExprLocCounter = .init(wip_nav.dwarf, &counter_bw); + counter.count += try loc.write(&counter); const adapter: struct { wip_nav: *WipNav, @@ -1812,8 +1814,8 @@ pub const WipNav = struct { try ctx.wip_nav.infoSectionOffset(.debug_info, unit, entry, 0); } } = .{ .wip_nav = wip_nav }; - try uleb128(adapter.writer(), counter.stream.bytes_written); - try loc.write(adapter); + try uleb128(adapter.writer(), counter.count); + _ = try loc.write(adapter); } fn infoAddrSym(wip_nav: *WipNav, sym_index: u32, sym_off: u64) UpdateError!void { @@ -1826,8 +1828,10 @@ pub const WipNav = struct { } fn frameExprLoc(wip_nav: *WipNav, loc: Loc) UpdateError!void { - var counter: ExprLocCounter = .init(wip_nav.dwarf); - try loc.write(&counter); + var buffer: [std.atomic.cache_line]u8 = undefined; + var counter_bw = std.io.Writer.null.buffered(&buffer); + var counter: ExprLocCounter = .init(wip_nav.dwarf, &counter_bw); + counter.count += try loc.write(&counter); const adapter: struct { wip_nav: *WipNav, @@ -1844,8 +1848,8 @@ pub const WipNav = struct { try ctx.wip_nav.sectionOffset(.debug_frame, .debug_info, unit, entry, 0); } } = .{ .wip_nav = wip_nav }; - try uleb128(adapter.writer(), counter.stream.bytes_written); - try loc.write(adapter); + try uleb128(adapter.writer(), counter.count); + _ = try loc.write(adapter); } fn frameAddrSym(wip_nav: *WipNav, sym_index: u32, sym_off: u64) UpdateError!void { @@ -6015,15 +6019,21 @@ fn sectionOffsetBytes(dwarf: *Dwarf) u32 { } fn uleb128Bytes(value: anytype) u32 { - var cw = std.io.countingWriter(std.io.null_writer); - try uleb128(cw.writer(), value); - return @intCast(cw.bytes_written); + var buffer: [std.atomic.cache_line]u8 = undefined; + var bw: std.io.BufferedWriter = .{ + .unbuffered_writer = .null, + .buffer = .initBuffer(&buffer), + }; + return try std.leb.writeUleb128Count(&bw, value); } fn sleb128Bytes(value: anytype) u32 { - var cw = std.io.countingWriter(std.io.null_writer); - try sleb128(cw.writer(), value); - return @intCast(cw.bytes_written); + var buffer: [std.atomic.cache_line]u8 = undefined; + var bw: std.io.BufferedWriter = .{ + .unbuffered_writer = .null, + .buffer = .initBuffer(&buffer), + }; + return try std.leb.writeIleb128Count(&bw, value); } /// overrides `-fno-incremental` for testing incremental debug info until `-fincremental` is functional