From 6eac56caf715bdc2dbd63b628ca48c0e32d5a70c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 31 Jul 2025 17:31:54 -0700 Subject: [PATCH] std.compress.flate.Decompress: allow users to swap out Writer --- lib/std/compress/flate/Decompress.zig | 33 +++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig index 8b16b3b353..05a2354f09 100644 --- a/lib/std/compress/flate/Decompress.zig +++ b/lib/std/compress/flate/Decompress.zig @@ -58,7 +58,7 @@ pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress { .reader = .{ .vtable = &.{ .stream = stream, - .rebase = rebase, + .rebase = rebaseFallible, .discard = discard, .readVec = readVec, }, @@ -78,12 +78,19 @@ pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress { }; } -fn rebase(r: *Reader, capacity: usize) Reader.RebaseError!void { +fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void { + const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); + rebase(d, capacity); +} + +fn rebase(d: *Decompress, capacity: usize) void { + const r = &d.reader; assert(capacity <= r.buffer.len - flate.history_len); assert(r.end + capacity > r.buffer.len); const discard_n = r.end - flate.history_len; const keep = r.buffer[discard_n..r.end]; @memmove(r.buffer[0..keep.len], keep); + assert(keep.len != 0); r.end = keep.len; r.seek -= discard_n; } @@ -101,6 +108,7 @@ fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { .end = r.end, }; defer { + assert(writer.end != 0); r.end = writer.end; r.seek = r.end; } @@ -115,14 +123,20 @@ fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize { _ = data; - assert(r.seek == r.end); - r.rebase(flate.history_len) catch unreachable; + const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); + return streamIndirect(d); +} + +fn streamIndirect(d: *Decompress) Reader.Error!usize { + const r = &d.reader; + if (r.end + flate.history_len > r.buffer.len) rebase(d, flate.history_len); var writer: Writer = .{ .buffer = r.buffer, .end = r.end, .vtable = &.{ .drain = Writer.fixedDrain }, }; - r.end += r.vtable.stream(r, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) { + defer r.end = writer.end; + _ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) { error.WriteFailed => unreachable, else => |e| return e, }; @@ -188,7 +202,12 @@ fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol { pub fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); - return readInner(d, w, limit) catch |err| switch (err) { + if (w.end >= r.end) return streamFallible(d, w, limit); + return streamIndirect(d); +} + +fn streamFallible(d: *Decompress, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { + return streamInner(d, w, limit) catch |err| switch (err) { error.EndOfStream => { if (d.state == .end) { return error.EndOfStream; @@ -207,7 +226,7 @@ pub fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!us }; } -fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.StreamError)!usize { +fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.StreamError)!usize { var remaining = @intFromEnum(limit); const in = d.input; sw: switch (d.state) {