thinking about splat being the only function

This commit is contained in:
Andrew Kelley 2025-02-15 18:45:21 -08:00
parent 5356f3a307
commit b26aceba7d
6 changed files with 196 additions and 113 deletions

View File

@ -1592,6 +1592,7 @@ pub fn writer(file: File) std.io.Writer {
.context = interface.handleToOpaque(file.handle),
.vtable = &.{
.writev = interface.writev,
.splat = interface.splat,
.writeFile = interface.writeFile,
},
};
@ -1628,6 +1629,28 @@ const interface = struct {
return std.posix.writev(file, iovecs);
}
fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const file = opaqueToHandle(context);
if (is_windows) {
// TODO improve this to use WriteFileScatter
if (headers.len > 0) {
const first = headers[0];
return windows.WriteFile(file, first, null);
}
if (n > 0) return windows.WriteFile(file, pattern, null);
return 0;
}
var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined;
const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, headers.len)];
for (iovecs, headers[0..iovecs.len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
return std.posix.writev(file, iovecs);
}
fn writeFile(
context: *anyopaque,
in_file: std.fs.File,

View File

@ -345,15 +345,23 @@ pub const null_writer: Writer = .{
.context = undefined,
.vtable = &.{
.writev = null_writev,
.splat = null_splat,
.writeFile = null_writeFile,
},
};
fn null_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
_ = context;
var n: usize = 0;
for (data) |bytes| n += bytes.len;
return n;
var written: usize = 0;
for (data) |bytes| written += bytes.len;
return written;
}
fn null_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
_ = context;
var written: usize = pattern.len * n;
for (headers) |bytes| written += bytes.len;
return written;
}
fn null_writeFile(

View File

@ -99,20 +99,29 @@ pub fn clearRetainingCapacity(aw: *AllocatingWriter) void {
}
fn writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
return splat(context, data, &.{}, 0);
}
fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const aw: *AllocatingWriter = @alignCast(@ptrCast(context));
const start_len = aw.written.len;
const bw = &aw.buffered_writer;
assert(data[0].ptr == aw.written.ptr + start_len);
assert(headers[0].ptr == aw.written.ptr + start_len);
var list: std.ArrayListUnmanaged(u8) = .{
.items = aw.written.ptr[0 .. start_len + data[0].len],
.items = aw.written.ptr[0 .. start_len + headers[0].len],
.capacity = start_len + bw.buffer.len,
};
defer setArrayList(aw, list);
const rest = data[1..];
var new_capacity: usize = list.capacity;
const rest = headers[1..];
var new_capacity: usize = list.capacity + pattern.len * n;
for (rest) |bytes| new_capacity += bytes.len;
try list.ensureTotalCapacity(aw.allocator, new_capacity + 1);
for (rest) |bytes| list.appendSliceAssumeCapacity(bytes);
if (pattern.len == 1) {
list.appendNTimesAssumeCapacity(pattern[0], n);
} else {
for (0..n) |_| list.appendSliceAssumeCapacity(pattern);
}
aw.written = list.items;
bw.buffer = list.unusedCapacitySlice();
return list.items.len - start_len;

View File

@ -6,6 +6,13 @@ const Writer = std.io.Writer;
const testing = std.testing;
/// Underlying stream to send bytes to.
///
/// A write will only be sent here if it could not fit into `buffer`, or if it
/// is a `writeFile`.
///
/// `unbuffered_writer` may modify `buffer` if the number of bytes returned
/// equals number of bytes provided. This property is exploited by
/// `std.io.AllocatingWriter` for example.
unbuffered_writer: Writer,
/// User-provided storage that must outlive this `BufferedWriter`.
///
@ -27,6 +34,7 @@ pub fn writer(bw: *BufferedWriter) Writer {
.context = bw,
.vtable = &.{
.writev = passthru_writev,
.splat = passthru_splat,
.writeFile = passthru_writeFile,
},
};
@ -34,6 +42,7 @@ pub fn writer(bw: *BufferedWriter) Writer {
const fixed_vtable: Writer.VTable = .{
.writev = fixed_writev,
.splat = fixed_splat,
.writeFile = fixed_writeFile,
};
@ -62,7 +71,8 @@ pub fn reset(bw: *BufferedWriter) void {
}
pub fn flush(bw: *BufferedWriter) anyerror!void {
try bw.unbuffered_writer.writeAll(bw.buffer[0..bw.end]);
const send_buffer = bw.buffer[0..bw.end];
try bw.unbuffered_writer.writeAll(send_buffer);
bw.end = 0;
}
@ -120,6 +130,102 @@ fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize
return end - start_end;
}
fn passthru_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
const buffer = bw.buffer;
const start_end = bw.end;
var end = bw.end;
for (headers, 0..) |bytes, i| {
const new_end = end + bytes.len;
if (new_end <= buffer.len) {
@branchHint(.likely);
@memcpy(buffer[end..new_end], bytes);
end = new_end;
continue;
}
if (end == 0) return bw.unbuffered_writer.splat(headers, pattern, n);
var buffers: [max_buffers_len][]const u8 = undefined;
buffers[0] = buffer[0..end];
const remaining_headers = headers[i..];
const remaining_buffers = buffers[1..];
const len: usize = @min(remaining_headers.len, remaining_buffers.len);
@memcpy(remaining_buffers[0..len], remaining_headers[0..len]);
const send_buffers = buffers[0 .. len + 1];
if (len >= remaining_headers.len) {
@branchHint(.likely);
// Made it past the headers, so we can call `splat`.
const written = try bw.unbuffered_writer.splat(send_buffers, pattern, n);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
}
const written = try bw.unbuffered_writer.writev(send_buffers);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
}
switch (pattern.len) {
0 => {
bw.end = end;
return end - start_end;
},
1 => {
const new_end = end + n;
if (new_end <= buffer.len) {
@branchHint(.likely);
@memset(buffer[end..new_end], pattern[0]);
bw.end = new_end;
return end - start_end;
}
const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
},
else => {
const new_end = end + pattern.len * n;
if (new_end <= buffer.len) {
@branchHint(.likely);
while (end < new_end) : (end += pattern.len) {
@memcpy(buffer[end..][0..pattern.len], pattern);
}
bw.end = end;
return end - start_end;
}
const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
},
}
}
fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
// When this function is called it means the buffer got full, so it's time
@ -131,6 +237,19 @@ fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
return error.NoSpaceLeft;
}
fn fixed_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
const dest = bw.buffer[bw.end..];
if (headers.len > 0) {
@memcpy(dest, headers[0][0..dest.len]);
} else switch (pattern.len) {
0 => unreachable,
1 => @memset(dest, pattern[0]),
else => for (0..n) |i| @memcpy(dest[i * pattern.len ..][0..pattern.len], pattern),
}
return error.NoSpaceLeft;
}
pub fn write(bw: *BufferedWriter, bytes: []const u8) anyerror!usize {
const buffer = bw.buffer;
const end = bw.end;
@ -210,71 +329,7 @@ pub fn splatByteAll(bw: *BufferedWriter, byte: u8, n: usize) anyerror!void {
///
/// Does maximum of one underlying `Writer.VTable.writev`.
pub fn splatByte(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize {
const buffer = bw.buffer;
const end = bw.end;
const new_end = end + n;
if (new_end <= buffer.len) {
@branchHint(.likely);
@memset(buffer[end..][0..n], byte);
bw.end = new_end;
return n;
}
if (n <= buffer.len) {
const written = try bw.unbuffered_writer.write(buffer[0..end]);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return 0;
}
assert(bw.buffer.ptr == buffer.ptr); // TODO this is not a valid assertion
@memset(buffer[0..n], byte);
bw.end = n;
return n;
}
// First try to use only the unused buffer region, to make an attempt for a
// single `writev`.
const free_space = buffer[end..];
var remaining = n - free_space.len;
@memset(free_space, byte);
var buffers: [max_buffers_len][]const u8 = undefined;
buffers[0] = buffer;
var buffer_i: usize = 1;
while (remaining > free_space.len and buffer_i < buffers.len) {
buffers[buffer_i] = free_space;
buffer_i += 1;
remaining -= free_space.len;
}
if (remaining > 0 and buffer_i < buffers.len) {
buffers[buffer_i] = free_space[0..remaining];
buffer_i += 1;
const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return 0;
}
bw.end = 0;
return written - end;
}
const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return 0;
}
bw.end = 0;
return written - end;
return passthru_splat(bw, &.{}, &.{byte}, n);
}
/// Writes the same slice many times, performing the underlying write call as
@ -288,40 +343,7 @@ pub fn splatBytesAll(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror!
///
/// Does maximum of one underlying `Writer.VTable.writev`.
pub fn splatBytes(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror!usize {
const buffer = bw.buffer;
const start_end = bw.end;
var end = start_end;
var remaining = n;
while (remaining > 0 and end + bytes.len <= buffer.len) {
@memcpy(buffer[end..][0..bytes.len], bytes);
end += bytes.len;
remaining -= 1;
}
if (remaining == 0) {
bw.end = end;
return end - start_end;
}
var buffers: [max_buffers_len][]const u8 = undefined;
var buffer_i: usize = 1;
buffers[0] = buffer[0..end];
const remaining_buffers = buffers[1..];
const buffers_len: usize = @min(remaining, remaining_buffers.len);
@memset(remaining_buffers[0..buffers_len], bytes);
remaining -= buffers_len;
buffer_i += buffers_len;
const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
return passthru_splat(bw, &.{}, bytes, n);
}
/// Asserts the `buffer` was initialized with a capacity of at least `@sizeOf(T)` bytes.

View File

@ -14,6 +14,7 @@ pub fn writer(cw: *CountingWriter) Writer {
.context = cw,
.vtable = &.{
.writev = passthru_writev,
.splat = passthru_splat,
.writeFile = passthru_writeFile,
},
};
@ -28,9 +29,16 @@ pub fn unbufferedWriter(cw: *CountingWriter) std.io.BufferedWriter {
fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const n = try cw.child_writer.writev(data);
cw.bytes_written += n;
return n;
const written = try cw.child_writer.writev(data);
cw.bytes_written += written;
return written;
}
fn passthru_splat(context: *anyopaque, header: []const u8, pattern: []const u8, n: usize) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const written = try cw.child_writer.splat(header, pattern, n);
cw.bytes_written += written;
return written;
}
fn passthru_writeFile(
@ -42,9 +50,9 @@ fn passthru_writeFile(
headers_len: usize,
) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const n = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len);
cw.bytes_written += n;
return n;
const written = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len);
cw.bytes_written += written;
return written;
}
test CountingWriter {

View File

@ -15,6 +15,19 @@ pub const VTable = struct {
/// of stream via an error.
writev: *const fn (context: *anyopaque, data: []const []const u8) anyerror!usize,
/// `headers_and_pattern` must have length of at least one. The last slice
/// is `pattern` which is the byte sequence to repeat `n` times. The rest
/// of the slices are headers to write before the pattern.
///
/// When `n == 1`, this is equivalent to `writev`.
///
/// Number of bytes actually written is returned.
///
/// Number of bytes returned may be zero, which does not mean
/// end-of-stream. A subsequent call may return nonzero, or may signal end
/// of stream via an error.
splat: *const fn (context: *anyopaque, headers_and_pattern: []const []const u8, n: usize) anyerror!usize,
/// Writes contents from an open file. `headers` are written first, then `len`
/// bytes of `file` starting from `offset`, then `trailers`.
///