diff --git a/lib/std/compress/flate.zig b/lib/std/compress/flate.zig index 354f95b6cc..65af44b7b4 100644 --- a/lib/std/compress/flate.zig +++ b/lib/std/compress/flate.zig @@ -13,7 +13,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void { /// Decompressor type pub fn Decompressor(comptime ReaderType: type) type { - return inflate.Inflate(.raw, ReaderType); + return inflate.Decompressor(.raw, ReaderType); } /// Create Decompressor which will read compressed data from reader. diff --git a/lib/std/compress/flate/bit_reader.zig b/lib/std/compress/flate/bit_reader.zig index 193849836e..8fc94cd4b4 100644 --- a/lib/std/compress/flate/bit_reader.zig +++ b/lib/std/compress/flate/bit_reader.zig @@ -2,8 +2,16 @@ const std = @import("std"); const assert = std.debug.assert; const testing = std.testing; -pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { - return BitReader(@TypeOf(reader)).init(reader); +pub fn bitReader(comptime T: type, reader: anytype) BitReader(T, @TypeOf(reader)) { + return BitReader(T, @TypeOf(reader)).init(reader); +} + +pub fn BitReader64(comptime ReaderType: type) type { + return BitReader(u64, ReaderType); +} + +pub fn BitReader32(comptime ReaderType: type) type { + return BitReader(u32, ReaderType); } /// Bit reader used during inflate (decompression). Has internal buffer of 64 @@ -15,12 +23,16 @@ pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) { /// fill buffer from forward_reader by calling fill in advance and readF with /// buffered flag set. 
/// -pub fn BitReader(comptime ReaderType: type) type { +pub fn BitReader(comptime T: type, comptime ReaderType: type) type { + assert(T == u32 or T == u64); + const t_bytes: usize = @sizeOf(T); + const Tshift = if (T == u64) u6 else u5; + return struct { // Underlying reader used for filling internal bits buffer forward_reader: ReaderType = undefined, // Internal buffer of 64 bits - bits: u64 = 0, + bits: T = 0, // Number of bits in the buffer nbits: u32 = 0, @@ -44,21 +56,21 @@ pub fn BitReader(comptime ReaderType: type) type { /// that number of bits available. If end of forward stream is reached /// it may be some extra zero bits in buffer. pub inline fn fill(self: *Self, nice: u6) !void { - if (self.nbits >= nice) { + if (self.nbits >= nice and nice != 0) { return; // We have enought bits } // Read more bits from forward reader // Number of empty bytes in bits, round nbits to whole bytes. const empty_bytes = - @as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise + @as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // t_bytes for 0, 8, 16..., t_bytes - 1 otherwise (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8 - var buf: [8]u8 = [_]u8{0} ** 8; + var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes; const bytes_read = self.forward_reader.readAll(buf[0..empty_bytes]) catch 0; if (bytes_read > 0) { - const u: u64 = std.mem.readInt(u64, buf[0..8], .little); - self.bits |= u << @as(u6, @intCast(self.nbits)); + const u: T = std.mem.readInt(T, buf[0..t_bytes], .little); + self.bits |= u << @as(Tshift, @intCast(self.nbits)); self.nbits += 8 * @as(u8, @intCast(bytes_read)); return; } @@ -99,7 +111,17 @@ pub fn BitReader(comptime ReaderType: type) type { /// Read with flags provided. 
pub fn readF(self: *Self, comptime U: type, comptime how: u3) !U { - const n: u6 = @bitSizeOf(U); + if (U == T) { + assert(how == 0); + assert(self.alignBits() == 0); + try self.fill(@bitSizeOf(T)); + if (self.nbits != @bitSizeOf(T)) return error.EndOfStream; + const v = self.bits; + self.nbits = 0; + self.bits = 0; + return v; + } + const n: Tshift = @bitSizeOf(U); switch (how) { 0 => { // `normal` read try self.fill(n); // ensure that there are n bits in the buffer @@ -157,7 +179,7 @@ pub fn BitReader(comptime ReaderType: type) type { } /// Advance buffer for n bits. - pub fn shift(self: *Self, n: u6) !void { + pub fn shift(self: *Self, n: Tshift) !void { if (n > self.nbits) return error.EndOfStream; self.bits >>= n; self.nbits -= n; @@ -218,10 +240,10 @@ pub fn BitReader(comptime ReaderType: type) type { }; } -test "BitReader" { +test "readF" { var fbs = std.io.fixedBufferStream(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 }); - var br = bitReader(fbs.reader()); - const F = BitReader(@TypeOf(fbs.reader())).flag; + var br = bitReader(u64, fbs.reader()); + const F = BitReader64(@TypeOf(fbs.reader())).flag; try testing.expectEqual(@as(u8, 48), br.nbits); try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits); @@ -254,36 +276,38 @@ test "BitReader" { } test "read block type 1 data" { - const data = [_]u8{ - 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1 - 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00, - 0x0c, 0x01, 0x02, 0x03, // - 0xaa, 0xbb, 0xcc, 0xdd, - }; - var fbs = std.io.fixedBufferStream(&data); - var br = bitReader(fbs.reader()); - const F = BitReader(@TypeOf(fbs.reader())).flag; + inline for ([_]type{ u64, u32 }) |T| { + const data = [_]u8{ + 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1 + 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00, + 0x0c, 0x01, 0x02, 0x03, // + 0xaa, 0xbb, 0xcc, 0xdd, + }; + var fbs = std.io.fixedBufferStream(&data); + var br = bitReader(T, fbs.reader()); + const F = BitReader(T, 
@TypeOf(fbs.reader())).flag; - try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal - try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type + try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal + try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type - for ("Hello world\n") |c| { - try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30); + for ("Hello world\n") |c| { + try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30); + } + try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block + br.alignToByte(); + try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0)); + try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0)); + try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0)); } - try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block - br.alignToByte(); - try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0)); - try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0)); - try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0)); } -test "init" { +test "shift/fill" { const data = [_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, }; var fbs = std.io.fixedBufferStream(&data); - var br = bitReader(fbs.reader()); + var br = bitReader(u64, fbs.reader()); try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits); try br.shift(8); @@ -303,31 +327,96 @@ test "init" { } test "readAll" { - const data = [_]u8{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - }; - var fbs = std.io.fixedBufferStream(&data); - var br = bitReader(fbs.reader()); + inline for ([_]type{ u64, u32 }) |T| { + const data = [_]u8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + }; + var fbs = std.io.fixedBufferStream(&data); + var br = bitReader(T, 
fbs.reader()); - try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits); + switch (T) { + u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits), + u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits), + else => unreachable, + } - var out: [16]u8 = undefined; - try br.readAll(out[0..]); - try testing.expect(br.nbits == 0); - try testing.expect(br.bits == 0); + var out: [16]u8 = undefined; + try br.readAll(out[0..]); + try testing.expect(br.nbits == 0); + try testing.expect(br.bits == 0); - try testing.expectEqualSlices(u8, data[0..16], &out); + try testing.expectEqualSlices(u8, data[0..16], &out); + } } test "readFixedCode" { - const fixed_codes = @import("huffman_encoder.zig").fixed_codes; + inline for ([_]type{ u64, u32 }) |T| { + const fixed_codes = @import("huffman_encoder.zig").fixed_codes; - var fbs = std.io.fixedBufferStream(&fixed_codes); - var rdr = bitReader(fbs.reader()); + var fbs = std.io.fixedBufferStream(&fixed_codes); + var rdr = bitReader(T, fbs.reader()); - for (0..286) |c| { - try testing.expectEqual(c, try rdr.readFixedCode()); + for (0..286) |c| { + try testing.expectEqual(c, try rdr.readFixedCode()); + } + try testing.expect(rdr.nbits == 0); } - try testing.expect(rdr.nbits == 0); +} + +test "u32 leaves no bits on u32 reads" { + const data = [_]u8{ + 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + var fbs = std.io.fixedBufferStream(&data); + var br = bitReader(u32, fbs.reader()); + + _ = try br.read(u3); + try testing.expectEqual(29, br.nbits); + br.alignToByte(); + try testing.expectEqual(24, br.nbits); + try testing.expectEqual(0x04_03_02_01, try br.read(u32)); + try testing.expectEqual(0, br.nbits); + try testing.expectEqual(0x08_07_06_05, try br.read(u32)); + try testing.expectEqual(0, br.nbits); + + _ = try br.read(u9); + try testing.expectEqual(23, br.nbits); + br.alignToByte(); + try testing.expectEqual(16, br.nbits); + try 
testing.expectEqual(0x0e_0d_0c_0b, try br.read(u32)); + try testing.expectEqual(0, br.nbits); +} + +test "u64 need fill after alignToByte" { + const data = [_]u8{ + 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + + // without fill + var fbs = std.io.fixedBufferStream(&data); + var br = bitReader(u64, fbs.reader()); + _ = try br.read(u23); + try testing.expectEqual(41, br.nbits); + br.alignToByte(); + try testing.expectEqual(40, br.nbits); + try testing.expectEqual(0x06_05_04_03, try br.read(u32)); + try testing.expectEqual(8, br.nbits); + try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); + try testing.expectEqual(32, br.nbits); + + // fill after align ensures all bits filled + fbs.reset(); + br = bitReader(u64, fbs.reader()); + _ = try br.read(u23); + try testing.expectEqual(41, br.nbits); + br.alignToByte(); + try br.fill(0); + try testing.expectEqual(64, br.nbits); + try testing.expectEqual(0x06_05_04_03, try br.read(u32)); + try testing.expectEqual(32, br.nbits); + try testing.expectEqual(0x0a_09_08_07, try br.read(u32)); + try testing.expectEqual(0, br.nbits); } diff --git a/lib/std/compress/flate/container.zig b/lib/std/compress/flate/container.zig index 23eec920de..fe6dec446d 100644 --- a/lib/std/compress/flate/container.zig +++ b/lib/std/compress/flate/container.zig @@ -154,6 +154,7 @@ pub const Container = enum { pub fn parseFooter(comptime wrap: Container, hasher: *Hasher(wrap), reader: anytype) !void { switch (wrap) { .gzip => { + try reader.fill(0); if (try reader.read(u32) != hasher.chksum()) return error.WrongGzipChecksum; if (try reader.read(u32) != hasher.bytesRead()) return error.WrongGzipSize; }, diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index a6aee7b56f..cf23961b21 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -17,8 +17,16 @@ pub fn decompress(comptime container: Container, reader: anytype, 
writer: anytyp } /// Inflate decompressor for the reader type. -pub fn decompressor(comptime container: Container, reader: anytype) Inflate(container, @TypeOf(reader)) { - return Inflate(container, @TypeOf(reader)).init(reader); +pub fn decompressor(comptime container: Container, reader: anytype) Decompressor(container, @TypeOf(reader)) { + return Decompressor(container, @TypeOf(reader)).init(reader); +} + +pub fn Decompressor(comptime container: Container, comptime ReaderType: type) type { + // zlib has a 4 byte footer, so a lookahead of 4 bytes ensures that we will not overshoot. + // gzip has an 8 byte footer, so we will not overshoot even with 8 bytes of lookahead. + // For raw deflate there is always a possibility of overshoot, so we use 8 bytes of lookahead. + const lookahead: type = if (container == .zlib) u32 else u64; + return Inflate(container, lookahead, ReaderType); } /// Inflate decompresses deflate bit stream. Reads compressed data from reader @@ -40,9 +48,12 @@ pub fn decompressor(comptime container: Container, reader: anytype) Inflate(cont /// * 64K for history (CircularBuffer) /// * ~10K huffman decoders (Literal and DistanceDecoder) /// -pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { +pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comptime ReaderType: type) type { + assert(LookaheadType == u32 or LookaheadType == u64); + const BitReaderType = BitReader(LookaheadType, ReaderType); + return struct { - const BitReaderType = BitReader(ReaderType); + //const BitReaderType = BitReader(ReaderType); const F = BitReaderType.flag; bits: BitReaderType = .{}, @@ -219,9 +230,14 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { switch (sym.kind) { .literal => self.hist.write(sym.symbol), .match => { // Decode match backreference - try self.bits.fill(5 + 15 + 13); // so we can use buffered reads + // fill so we can use buffered reads + if (LookaheadType == u32) + try self.bits.fill(5 
+ 15) + else + try self.bits.fill(5 + 15 + 13); const length = try self.decodeLength(sym.symbol); const dsm = try self.decodeSymbol(&self.dst_dec); + if (LookaheadType == u32) try self.bits.fill(13); const distance = try self.decodeDistance(dsm.symbol); try self.hist.writeMatch(length, distance); }, diff --git a/lib/std/compress/gzip.zig b/lib/std/compress/gzip.zig index 8bb09c612a..e619b575de 100644 --- a/lib/std/compress/gzip.zig +++ b/lib/std/compress/gzip.zig @@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void { /// Decompressor type pub fn Decompressor(comptime ReaderType: type) type { - return inflate.Inflate(.gzip, ReaderType); + return inflate.Decompressor(.gzip, ReaderType); } /// Create Decompressor which will read compressed data from reader. diff --git a/lib/std/compress/zlib.zig b/lib/std/compress/zlib.zig index 33401ce845..554f6f894b 100644 --- a/lib/std/compress/zlib.zig +++ b/lib/std/compress/zlib.zig @@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void { /// Decompressor type pub fn Decompressor(comptime ReaderType: type) type { - return inflate.Inflate(.zlib, ReaderType); + return inflate.Decompressor(.zlib, ReaderType); } /// Create Decompressor which will read compressed data from reader. @@ -64,3 +64,38 @@ pub const store = struct { return deflate.store.compressor(.zlib, writer); } }; + +test "should not overshoot" { + const std = @import("std"); + + // Compressed zlib data with extra 4 bytes at the end. 
+ const data = [_]u8{ + 0x78, 0x9c, 0x73, 0xce, 0x2f, 0xa8, 0x2c, 0xca, 0x4c, 0xcf, 0x28, 0x51, 0x08, 0xcf, 0xcc, 0xc9, + 0x49, 0xcd, 0x55, 0x28, 0x4b, 0xcc, 0x53, 0x08, 0x4e, 0xce, 0x48, 0xcc, 0xcc, 0xd6, 0x51, 0x08, + 0xce, 0xcc, 0x4b, 0x4f, 0x2c, 0xc8, 0x2f, 0x4a, 0x55, 0x30, 0xb4, 0xb4, 0x34, 0xd5, 0xb5, 0x34, + 0x03, 0x00, 0x8b, 0x61, 0x0f, 0xa4, 0x52, 0x5a, 0x94, 0x12, + }; + + var stream = std.io.fixedBufferStream(data[0..]); + const reader = stream.reader(); + + var dcp = decompressor(reader); + var out: [128]u8 = undefined; + + // Decompress + var n = try dcp.reader().readAll(out[0..]); + + // Expected decompressed data + try std.testing.expectEqual(46, n); + try std.testing.expectEqualStrings("Copyright Willem van Schaik, Singapore 1995-96", out[0..n]); + + // The decompressor doesn't overshoot the underlying reader. + // It leaves it at the end of the compressed data chunk. + try std.testing.expectEqual(data.len - 4, stream.getPos()); + try std.testing.expectEqual(0, dcp.unreadBytes()); + + // The 4 bytes after the compressed chunk are still available in the reader. + n = try reader.readAll(out[0..]); + try std.testing.expectEqual(n, 4); + try std.testing.expectEqualSlices(u8, data[data.len - 4 .. data.len], out[0..n]); +}