std.Io.Writer: introduce rebase to the vtable

fixes #24814
This commit is contained in:
Andrew Kelley 2025-08-13 22:02:30 -07:00
parent 96e4825fbb
commit af7e142485
2 changed files with 103 additions and 61 deletions

View File

@ -4,7 +4,7 @@ const native_endian = builtin.target.cpu.arch.endian();
const Writer = @This(); const Writer = @This();
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert; const assert = std.debug.assert;
const Limit = std.io.Limit; const Limit = std.Io.Limit;
const File = std.fs.File; const File = std.fs.File;
const testing = std.testing; const testing = std.testing;
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
@ -76,6 +76,14 @@ pub const VTable = struct {
/// There may be subsequent calls to `drain` and `sendFile` after a `flush` /// There may be subsequent calls to `drain` and `sendFile` after a `flush`
/// operation. /// operation.
flush: *const fn (w: *Writer) Error!void = defaultFlush, flush: *const fn (w: *Writer) Error!void = defaultFlush,
/// Ensures `capacity` more bytes can be buffered without rebasing.
///
/// The most recent `preserve` bytes must remain buffered.
///
/// Only called when `capacity` bytes cannot fit into the unused capacity
/// of `buffer`.
rebase: *const fn (w: *Writer, preserve: usize, capacity: usize) Error!void = defaultRebase,
}; };
pub const Error = error{ pub const Error = error{
@ -117,6 +125,7 @@ pub fn fixed(buffer: []u8) Writer {
.vtable = &.{ .vtable = &.{
.drain = fixedDrain, .drain = fixedDrain,
.flush = noopFlush, .flush = noopFlush,
.rebase = failingRebase,
}, },
.buffer = buffer, .buffer = buffer,
}; };
@ -130,6 +139,7 @@ pub const failing: Writer = .{
.vtable = &.{ .vtable = &.{
.drain = failingDrain, .drain = failingDrain,
.sendFile = failingSendFile, .sendFile = failingSendFile,
.rebase = failingRebase,
}, },
}; };
@ -276,7 +286,7 @@ fn writeSplatHeaderLimitFinish(
test "writeSplatHeader splatting avoids buffer aliasing temptation" { test "writeSplatHeader splatting avoids buffer aliasing temptation" {
const initial_buf = try testing.allocator.alloc(u8, 8); const initial_buf = try testing.allocator.alloc(u8, 8);
var aw: std.io.Writer.Allocating = .initOwnedSlice(testing.allocator, initial_buf); var aw: Allocating = .initOwnedSlice(testing.allocator, initial_buf);
defer aw.deinit(); defer aw.deinit();
// This test assumes 8 vector buffer in this function. // This test assumes 8 vector buffer in this function.
const n = try aw.writer.writeSplatHeader("header which is longer than buf ", &.{ const n = try aw.writer.writeSplatHeader("header which is longer than buf ", &.{
@ -307,24 +317,41 @@ pub fn noopFlush(w: *Writer) Error!void {
test "fixed buffer flush" { test "fixed buffer flush" {
var buffer: [1]u8 = undefined; var buffer: [1]u8 = undefined;
var writer: std.io.Writer = .fixed(&buffer); var writer: Writer = .fixed(&buffer);
try writer.writeByte(10); try writer.writeByte(10);
try writer.flush(); try writer.flush();
try testing.expectEqual(10, buffer[0]); try testing.expectEqual(10, buffer[0]);
} }
/// Calls `VTable.drain` but hides the last `preserve_len` bytes from the pub fn rebase(w: *Writer, preserve: usize, unused_capacity_len: usize) Error!void {
/// implementation, keeping them buffered. if (w.buffer.len - w.end >= unused_capacity_len) {
pub fn drainPreserve(w: *Writer, preserve_len: usize) Error!void { @branchHint(.likely);
const preserved_head = w.end -| preserve_len; return;
const preserved_tail = w.end; }
const preserved_len = preserved_tail - preserved_head; try w.vtable.rebase(w, preserve, unused_capacity_len);
w.end = preserved_head; }
defer w.end += preserved_len;
assert(0 == try w.vtable.drain(w, &.{""}, 1)); pub fn defaultRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void {
assert(w.end <= preserved_head + preserved_len); while (w.buffer.len - w.end < minimum_len) {
@memmove(w.buffer[w.end..][0..preserved_len], w.buffer[preserved_head..preserved_tail]); {
// TODO: instead of this logic that "hides" data from
// the implementation, introduce a seek index to Writer
const preserved_head = w.end -| preserve;
const preserved_tail = w.end;
const preserved_len = preserved_tail - preserved_head;
w.end = preserved_head;
defer w.end += preserved_len;
assert(0 == try w.vtable.drain(w, &.{""}, 1));
assert(w.end <= preserved_head + preserved_len);
@memmove(w.buffer[w.end..][0..preserved_len], w.buffer[preserved_head..preserved_tail]);
}
// If the loop condition was false this assertion would have passed
// anyway. Otherwise, give the implementation a chance to grow the
// buffer before asserting on the buffer length.
assert(w.buffer.len - preserve >= minimum_len);
}
} }
pub fn unusedCapacitySlice(w: *const Writer) []u8 { pub fn unusedCapacitySlice(w: *const Writer) []u8 {
@ -353,53 +380,44 @@ pub fn writableSlice(w: *Writer, len: usize) Error![]u8 {
return big_slice[0..len]; return big_slice[0..len];
} }
/// Asserts the provided buffer has total capacity enough for `minimum_length`. /// Asserts the provided buffer has total capacity enough for `minimum_len`.
/// ///
/// Does not `advance` the buffer end position. /// Does not `advance` the buffer end position.
/// ///
/// If `minimum_length` is zero, this is equivalent to `unusedCapacitySlice`. /// If `minimum_len` is zero, this is equivalent to `unusedCapacitySlice`.
pub fn writableSliceGreedy(w: *Writer, minimum_length: usize) Error![]u8 { pub fn writableSliceGreedy(w: *Writer, minimum_len: usize) Error![]u8 {
while (w.buffer.len - w.end < minimum_length) { return writableSliceGreedyPreserve(w, 0, minimum_len);
assert(0 == try w.vtable.drain(w, &.{""}, 1));
// If the loop condition was false this assertion would have passed
// anyway. Otherwise, give the implementation a chance to grow the
// buffer before asserting on the buffer length.
assert(w.buffer.len >= minimum_length);
} else {
@branchHint(.likely);
return w.buffer[w.end..];
}
} }
/// Asserts the provided buffer has total capacity enough for `minimum_length` /// Asserts the provided buffer has total capacity enough for `minimum_len`
/// and `preserve_len` combined. /// and `preserve` combined.
/// ///
/// Does not `advance` the buffer end position. /// Does not `advance` the buffer end position.
/// ///
/// When draining the buffer, ensures that at least `preserve_len` bytes /// When draining the buffer, ensures that at least `preserve` bytes
/// remain buffered. /// remain buffered.
/// ///
/// If `preserve_len` is zero, this is equivalent to `writableSliceGreedy`. /// If `preserve` is zero, this is equivalent to `writableSliceGreedy`.
pub fn writableSliceGreedyPreserve(w: *Writer, preserve_len: usize, minimum_length: usize) Error![]u8 { pub fn writableSliceGreedyPreserve(w: *Writer, preserve: usize, minimum_len: usize) Error![]u8 {
assert(w.buffer.len >= preserve_len + minimum_length); if (w.buffer.len - w.end >= minimum_len) {
while (w.buffer.len - w.end < minimum_length) {
try drainPreserve(w, preserve_len);
} else {
@branchHint(.likely); @branchHint(.likely);
return w.buffer[w.end..]; return w.buffer[w.end..];
} }
try rebase(w, preserve, minimum_len);
assert(w.buffer.len >= preserve + minimum_len);
return w.buffer[w.end..];
} }
/// Asserts the provided buffer has total capacity enough for `len`. /// Asserts the provided buffer has total capacity enough for `len`.
/// ///
/// Advances the buffer end position by `len`. /// Advances the buffer end position by `len`.
/// ///
/// When draining the buffer, ensures that at least `preserve_len` bytes /// When draining the buffer, ensures that at least `preserve` bytes
/// remain buffered. /// remain buffered.
/// ///
/// If `preserve_len` is zero, this is equivalent to `writableSlice`. /// If `preserve` is zero, this is equivalent to `writableSlice`.
pub fn writableSlicePreserve(w: *Writer, preserve_len: usize, len: usize) Error![]u8 { pub fn writableSlicePreserve(w: *Writer, preserve: usize, len: usize) Error![]u8 {
const big_slice = try w.writableSliceGreedyPreserve(preserve_len, len); const big_slice = try w.writableSliceGreedyPreserve(preserve, len);
advance(w, len); advance(w, len);
return big_slice[0..len]; return big_slice[0..len];
} }
@ -708,16 +726,18 @@ pub fn writeByte(w: *Writer, byte: u8) Error!void {
} }
} }
/// When draining the buffer, ensures that at least `preserve_len` bytes /// When draining the buffer, ensures that at least `preserve` bytes
/// remain buffered. /// remain buffered.
pub fn writeBytePreserve(w: *Writer, preserve_len: usize, byte: u8) Error!void { pub fn writeBytePreserve(w: *Writer, preserve: usize, byte: u8) Error!void {
while (w.buffer.len - w.end == 0) { if (w.buffer.len - w.end != 0) {
try drainPreserve(w, preserve_len);
} else {
@branchHint(.likely); @branchHint(.likely);
w.buffer[w.end] = byte; w.buffer[w.end] = byte;
w.end += 1; w.end += 1;
return;
} }
try w.vtable.rebase(w, preserve, 1);
w.buffer[w.end] = byte;
w.end += 1;
} }
/// Writes the same byte many times, performing the underlying write call as /// Writes the same byte many times, performing the underlying write call as
@ -735,18 +755,18 @@ test splatByteAll {
try testing.expectEqualStrings("7" ** 45, aw.writer.buffered()); try testing.expectEqualStrings("7" ** 45, aw.writer.buffered());
} }
pub fn splatBytePreserve(w: *Writer, preserve_len: usize, byte: u8, n: usize) Error!void { pub fn splatBytePreserve(w: *Writer, preserve: usize, byte: u8, n: usize) Error!void {
const new_end = w.end + n; const new_end = w.end + n;
if (new_end <= w.buffer.len) { if (new_end <= w.buffer.len) {
@memset(w.buffer[w.end..][0..n], byte); @memset(w.buffer[w.end..][0..n], byte);
w.end = new_end; w.end = new_end;
return; return;
} }
// If `n` is large, we can ignore `preserve_len` up to a point. // If `n` is large, we can ignore `preserve` up to a point.
var remaining = n; var remaining = n;
while (remaining > preserve_len) { while (remaining > preserve) {
assert(remaining != 0); assert(remaining != 0);
remaining -= try splatByte(w, byte, remaining - preserve_len); remaining -= try splatByte(w, byte, remaining - preserve);
if (w.end + remaining <= w.buffer.len) { if (w.end + remaining <= w.buffer.len) {
@memset(w.buffer[w.end..][0..remaining], byte); @memset(w.buffer[w.end..][0..remaining], byte);
w.end += remaining; w.end += remaining;
@ -754,9 +774,9 @@ pub fn splatBytePreserve(w: *Writer, preserve_len: usize, byte: u8, n: usize) Er
} }
} }
// All the next bytes received must be preserved. // All the next bytes received must be preserved.
if (preserve_len < w.end) { if (preserve < w.end) {
@memmove(w.buffer[0..preserve_len], w.buffer[w.end - preserve_len ..][0..preserve_len]); @memmove(w.buffer[0..preserve], w.buffer[w.end - preserve ..][0..preserve]);
w.end = preserve_len; w.end = preserve;
} }
while (remaining > 0) remaining -= try w.splatByte(byte, remaining); while (remaining > 0) remaining -= try w.splatByte(byte, remaining);
} }
@ -1667,7 +1687,7 @@ pub const ByteSizeUnits = enum {
/// Format option `precision` is ignored when `value` is less than 1kB /// Format option `precision` is ignored when `value` is less than 1kB
pub fn printByteSize( pub fn printByteSize(
w: *std.io.Writer, w: *Writer,
value: u64, value: u64,
comptime units: ByteSizeUnits, comptime units: ByteSizeUnits,
options: std.fmt.Options, options: std.fmt.Options,
@ -2169,7 +2189,7 @@ test "fixed output" {
test "writeSplat 0 len splat larger than capacity" { test "writeSplat 0 len splat larger than capacity" {
var buf: [8]u8 = undefined; var buf: [8]u8 = undefined;
var w: std.io.Writer = .fixed(&buf); var w: Writer = .fixed(&buf);
const n = try w.writeSplat(&.{"something that overflows buf"}, 0); const n = try w.writeSplat(&.{"something that overflows buf"}, 0);
try testing.expectEqual(0, n); try testing.expectEqual(0, n);
} }
@ -2188,6 +2208,13 @@ pub fn failingSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) File
return error.WriteFailed; return error.WriteFailed;
} }
pub fn failingRebase(w: *Writer, preserve: usize, capacity: usize) Error!void {
_ = w;
_ = preserve;
_ = capacity;
return error.WriteFailed;
}
pub const Discarding = struct { pub const Discarding = struct {
count: u64, count: u64,
writer: Writer, writer: Writer,
@ -2455,7 +2482,7 @@ pub fn Hashing(comptime Hasher: type) type {
/// Maintains `Writer` state such that it writes to the unused capacity of an /// Maintains `Writer` state such that it writes to the unused capacity of an
/// array list, filling it up completely before making a call through the /// array list, filling it up completely before making a call through the
/// vtable, causing a resize. Consequently, the same, optimized, non-generic /// vtable, causing a resize. Consequently, the same, optimized, non-generic
/// machine code that uses `std.io.Reader`, such as formatted printing, takes /// machine code that uses `std.Io.Reader`, such as formatted printing, takes
/// the hot paths when using this API. /// the hot paths when using this API.
/// ///
/// When using this API, it is not necessary to call `flush`. /// When using this API, it is not necessary to call `flush`.
@ -2514,6 +2541,7 @@ pub const Allocating = struct {
.drain = Allocating.drain, .drain = Allocating.drain,
.sendFile = Allocating.sendFile, .sendFile = Allocating.sendFile,
.flush = noopFlush, .flush = noopFlush,
.rebase = growingRebase,
}; };
pub fn deinit(a: *Allocating) void { pub fn deinit(a: *Allocating) void {
@ -2595,7 +2623,7 @@ pub const Allocating = struct {
return list.items.len - start_len; return list.items.len - start_len;
} }
fn sendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) FileError!usize { fn sendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) FileError!usize {
if (File.Handle == void) return error.Unimplemented; if (File.Handle == void) return error.Unimplemented;
if (limit == .nothing) return 0; if (limit == .nothing) return 0;
const a: *Allocating = @fieldParentPtr("writer", w); const a: *Allocating = @fieldParentPtr("writer", w);
@ -2612,6 +2640,15 @@ pub const Allocating = struct {
return n; return n;
} }
fn growingRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void {
_ = preserve; // This implementation always preserves the entire buffer.
const a: *Allocating = @fieldParentPtr("writer", w);
const gpa = a.allocator;
var list = a.toArrayList();
defer setArrayList(a, list);
list.ensureUnusedCapacity(gpa, minimum_len) catch return error.WriteFailed;
}
fn setArrayList(a: *Allocating, list: std.ArrayListUnmanaged(u8)) void { fn setArrayList(a: *Allocating, list: std.ArrayListUnmanaged(u8)) void {
a.writer.buffer = list.allocatedSlice(); a.writer.buffer = list.allocatedSlice();
a.writer.end = list.items.len; a.writer.end = list.items.len;
@ -2645,7 +2682,7 @@ test "discarding sendFile" {
try file_reader.seekTo(0); try file_reader.seekTo(0);
var w_buffer: [256]u8 = undefined; var w_buffer: [256]u8 = undefined;
var discarding: std.io.Writer.Discarding = .init(&w_buffer); var discarding: Writer.Discarding = .init(&w_buffer);
_ = try file_reader.interface.streamRemaining(&discarding.writer); _ = try file_reader.interface.streamRemaining(&discarding.writer);
} }
@ -2664,7 +2701,7 @@ test "allocating sendFile" {
var file_reader = file_writer.moveToReader(); var file_reader = file_writer.moveToReader();
try file_reader.seekTo(0); try file_reader.seekTo(0);
var allocating: std.io.Writer.Allocating = .init(testing.allocator); var allocating: Writer.Allocating = .init(testing.allocator);
defer allocating.deinit(); defer allocating.deinit();
_ = try file_reader.interface.streamRemaining(&allocating.writer); _ = try file_reader.interface.streamRemaining(&allocating.writer);
@ -2702,3 +2739,10 @@ test writeSliceEndian {
try writeSliceEndian(&w, u16, &array, .big); try writeSliceEndian(&w, u16, &array, .big);
try testing.expectEqualSlices(u8, &.{ 'x', 0x12, 0x34, 0x56, 0x78 }, &buffer); try testing.expectEqualSlices(u8, &.{ 'x', 0x12, 0x34, 0x56, 0x78 }, &buffer);
} }
test "writableSlice with fixed writer" {
var buf: [2]u8 = undefined;
var w: std.Io.Writer = .fixed(&buf);
try w.writeByte(1);
try std.testing.expectError(error.WriteFailed, w.writableSlice(2));
}

View File

@ -76,7 +76,7 @@ const indirect_vtable: Reader.VTable = .{
/// `input` buffer is asserted to be at least 10 bytes, or EOF before then. /// `input` buffer is asserted to be at least 10 bytes, or EOF before then.
/// ///
/// If `buffer` is provided then asserted to have `flate.max_window_len` /// If `buffer` is provided then asserted to have `flate.max_window_len`
/// capacity, as well as `flate.history_len` unused capacity on every write. /// capacity.
pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress { pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress {
if (buffer.len != 0) assert(buffer.len >= flate.max_window_len); if (buffer.len != 0) assert(buffer.len >= flate.max_window_len);
return .{ return .{
@ -239,8 +239,6 @@ fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol {
} }
fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
assert(w.buffer.len >= flate.max_window_len);
assert(w.unusedCapacityLen() >= flate.history_len);
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
return streamFallible(d, w, limit); return streamFallible(d, w, limit);
} }