diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index f8547edea5..5f294b9fe9 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1592,6 +1592,7 @@ pub fn writer(file: File) std.io.Writer { .context = interface.handleToOpaque(file.handle), .vtable = &.{ .writev = interface.writev, + .splat = interface.splat, .writeFile = interface.writeFile, }, }; @@ -1628,6 +1629,28 @@ const interface = struct { return std.posix.writev(file, iovecs); } + fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize { + const file = opaqueToHandle(context); + + if (is_windows) { + // TODO improve this to use WriteFileScatter + if (headers.len > 0) { + const first = headers[0]; + return windows.WriteFile(file, first, null); + } + if (n > 0) return windows.WriteFile(file, pattern, null); + return 0; + } + + var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined; + const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, headers.len)]; + for (iovecs, headers[0..iovecs.len]) |*v, d| v.* = .{ + .base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length. + .len = d.len, + }; + return std.posix.writev(file, iovecs); + } + fn writeFile( context: *anyopaque, in_file: std.fs.File, diff --git a/lib/std/io.zig b/lib/std/io.zig index 3f38e1fc93..8d84a8a36f 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -345,15 +345,23 @@ pub const null_writer: Writer = .{ .context = undefined, .vtable = &.{ .writev = null_writev, + .splat = null_splat, .writeFile = null_writeFile, }, }; fn null_writev(context: *anyopaque, data: []const []const u8) anyerror!usize { _ = context; - var n: usize = 0; - for (data) |bytes| n += bytes.len; - return n; + var written: usize = 0; + for (data) |bytes| written += bytes.len; + return written; +} + +fn null_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize { + _ = context; + var written: usize = pattern.len * n; + for (headers) |bytes| written += bytes.len; + return written; } fn null_writeFile( diff --git a/lib/std/io/AllocatingWriter.zig b/lib/std/io/AllocatingWriter.zig index 8aaf028ad0..20cda9de19 100644 --- a/lib/std/io/AllocatingWriter.zig +++ b/lib/std/io/AllocatingWriter.zig @@ -99,20 +99,29 @@ pub fn clearRetainingCapacity(aw: *AllocatingWriter) void { } fn writev(context: *anyopaque, data: []const []const u8) anyerror!usize { + return splat(context, data, &.{}, 0); +} + +fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize { const aw: *AllocatingWriter = @alignCast(@ptrCast(context)); const start_len = aw.written.len; const bw = &aw.buffered_writer; - assert(data[0].ptr == aw.written.ptr + start_len); + assert(headers[0].ptr == aw.written.ptr + start_len); var list: std.ArrayListUnmanaged(u8) = .{ - .items = aw.written.ptr[0 .. start_len + data[0].len], + .items = aw.written.ptr[0 .. start_len + headers[0].len], .capacity = start_len + bw.buffer.len, }; defer setArrayList(aw, list); - const rest = data[1..]; - var new_capacity: usize = list.capacity; + const rest = headers[1..]; + var new_capacity: usize = list.capacity + pattern.len * n; for (rest) |bytes| new_capacity += bytes.len; try list.ensureTotalCapacity(aw.allocator, new_capacity + 1); for (rest) |bytes| list.appendSliceAssumeCapacity(bytes); + if (pattern.len == 1) { + list.appendNTimesAssumeCapacity(pattern[0], n); + } else { + for (0..n) |_| list.appendSliceAssumeCapacity(pattern); + } aw.written = list.items; bw.buffer = list.unusedCapacitySlice(); return list.items.len - start_len; diff --git a/lib/std/io/BufferedWriter.zig b/lib/std/io/BufferedWriter.zig index 679fc82ae1..ee15da1ea8 100644 --- a/lib/std/io/BufferedWriter.zig +++ b/lib/std/io/BufferedWriter.zig @@ -6,6 +6,13 @@ const Writer = std.io.Writer; const testing = std.testing; /// Underlying stream to send bytes to. +/// +/// A write will only be sent here if it could not fit into `buffer`, or if it +/// is a `writeFile`. +/// +/// `unbuffered_writer` may modify `buffer` if the number of bytes returned +/// equals number of bytes provided. This property is exploited by +/// `std.io.AllocatingWriter` for example. unbuffered_writer: Writer, /// User-provided storage that must outlive this `BufferedWriter`. /// @@ -27,6 +34,7 @@ pub fn writer(bw: *BufferedWriter) Writer { .context = bw, .vtable = &.{ .writev = passthru_writev, + .splat = passthru_splat, .writeFile = passthru_writeFile, }, }; @@ -34,6 +42,7 @@ pub fn writer(bw: *BufferedWriter) Writer { const fixed_vtable: Writer.VTable = .{ .writev = fixed_writev, + .splat = fixed_splat, .writeFile = fixed_writeFile, }; @@ -62,7 +71,8 @@ pub fn reset(bw: *BufferedWriter) void { } pub fn flush(bw: *BufferedWriter) anyerror!void { - try bw.unbuffered_writer.writeAll(bw.buffer[0..bw.end]); + const send_buffer = bw.buffer[0..bw.end]; + try bw.unbuffered_writer.writeAll(send_buffer); bw.end = 0; } @@ -120,6 +130,102 @@ fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize return end - start_end; } +fn passthru_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize { + const bw: *BufferedWriter = @alignCast(@ptrCast(context)); + const buffer = bw.buffer; + const start_end = bw.end; + + var end = bw.end; + for (headers, 0..) |bytes, i| { + const new_end = end + bytes.len; + if (new_end <= buffer.len) { + @branchHint(.likely); + @memcpy(buffer[end..new_end], bytes); + end = new_end; + continue; + } + if (end == 0) return bw.unbuffered_writer.splat(headers, pattern, n); + var buffers: [max_buffers_len][]const u8 = undefined; + buffers[0] = buffer[0..end]; + const remaining_headers = headers[i..]; + const remaining_buffers = buffers[1..]; + const len: usize = @min(remaining_headers.len, remaining_buffers.len); + @memcpy(remaining_buffers[0..len], remaining_headers[0..len]); + const send_buffers = buffers[0 .. len + 1]; + if (len >= remaining_headers.len) { + @branchHint(.likely); + // Made it past the headers, so we can call `splat`. + const written = try bw.unbuffered_writer.splat(send_buffers, pattern, n); + if (written < end) { + @branchHint(.unlikely); + const remainder = buffer[written..end]; + std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); + bw.end = remainder.len; + return end - start_end; + } + bw.end = 0; + return written - start_end; + } + const written = try bw.unbuffered_writer.writev(send_buffers); + if (written < end) { + @branchHint(.unlikely); + const remainder = buffer[written..end]; + std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); + bw.end = remainder.len; + return end - start_end; + } + bw.end = 0; + return written - start_end; + } + + switch (pattern.len) { + 0 => { + bw.end = end; + return end - start_end; + }, + 1 => { + const new_end = end + n; + if (new_end <= buffer.len) { + @branchHint(.likely); + @memset(buffer[end..new_end], pattern[0]); + bw.end = new_end; + return end - start_end; + } + const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n); + if (written < end) { + @branchHint(.unlikely); + const remainder = buffer[written..end]; + std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); + bw.end = remainder.len; + return end - start_end; + } + bw.end = 0; + return written - start_end; + }, + else => { + const new_end = end + pattern.len * n; + if (new_end <= buffer.len) { + @branchHint(.likely); + while (end < new_end) : (end += pattern.len) { + @memcpy(buffer[end..][0..pattern.len], pattern); + } + bw.end = end; + return end - start_end; + } + const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n); + if (written < end) { + @branchHint(.unlikely); + const remainder = buffer[written..end]; + std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); + bw.end = remainder.len; + return end - start_end; + } + bw.end = 0; + return written - start_end; + }, + } +} + fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize { const bw: *BufferedWriter = @alignCast(@ptrCast(context)); // When this function is called it means the buffer got full, so it's time @@ -131,6 +237,19 @@ fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize { return error.NoSpaceLeft; } +fn fixed_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize { + const bw: *BufferedWriter = @alignCast(@ptrCast(context)); + const dest = bw.buffer[bw.end..]; + if (headers.len > 0) { + @memcpy(dest, headers[0][0..dest.len]); + } else switch (pattern.len) { + 0 => unreachable, + 1 => @memset(dest, pattern[0]), + else => for (0..n) |i| @memcpy(dest[i * pattern.len ..][0..pattern.len], pattern), + } + return error.NoSpaceLeft; +} + pub fn write(bw: *BufferedWriter, bytes: []const u8) anyerror!usize { const buffer = bw.buffer; const end = bw.end; @@ -210,71 +329,7 @@ pub fn splatByteAll(bw: *BufferedWriter, byte: u8, n: usize) anyerror!void { /// /// Does maximum of one underlying `Writer.VTable.writev`. pub fn splatByte(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize { - const buffer = bw.buffer; - const end = bw.end; - - const new_end = end + n; - if (new_end <= buffer.len) { - @branchHint(.likely); - @memset(buffer[end..][0..n], byte); - bw.end = new_end; - return n; - } - - if (n <= buffer.len) { - const written = try bw.unbuffered_writer.write(buffer[0..end]); - if (written < end) { - @branchHint(.unlikely); - const remainder = buffer[written..end]; - std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); - bw.end = remainder.len; - return 0; - } - assert(bw.buffer.ptr == buffer.ptr); // TODO this is not a valid assertion - @memset(buffer[0..n], byte); - bw.end = n; - return n; - } - - // First try to use only the unused buffer region, to make an attempt for a - // single `writev`. - const free_space = buffer[end..]; - var remaining = n - free_space.len; - @memset(free_space, byte); - var buffers: [max_buffers_len][]const u8 = undefined; - buffers[0] = buffer; - var buffer_i: usize = 1; - while (remaining > free_space.len and buffer_i < buffers.len) { - buffers[buffer_i] = free_space; - buffer_i += 1; - remaining -= free_space.len; - } - if (remaining > 0 and buffer_i < buffers.len) { - buffers[buffer_i] = free_space[0..remaining]; - buffer_i += 1; - const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]); - if (written < end) { - @branchHint(.unlikely); - const remainder = buffer[written..end]; - std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); - bw.end = remainder.len; - return 0; - } - bw.end = 0; - return written - end; - } - - const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]); - if (written < end) { - @branchHint(.unlikely); - const remainder = buffer[written..end]; - std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); - bw.end = remainder.len; - return 0; - } - - bw.end = 0; - return written - end; + return passthru_splat(bw, &.{}, &.{byte}, n); } /// Writes the same slice many times, performing the underlying write call as @@ -288,40 +343,7 @@ pub fn splatBytesAll(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror! /// /// Does maximum of one underlying `Writer.VTable.writev`. pub fn splatBytes(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror!usize { - const buffer = bw.buffer; - const start_end = bw.end; - var end = start_end; - var remaining = n; - while (remaining > 0 and end + bytes.len <= buffer.len) { - @memcpy(buffer[end..][0..bytes.len], bytes); - end += bytes.len; - remaining -= 1; - } - - if (remaining == 0) { - bw.end = end; - return end - start_end; - } - - var buffers: [max_buffers_len][]const u8 = undefined; - var buffer_i: usize = 1; - buffers[0] = buffer[0..end]; - const remaining_buffers = buffers[1..]; - const buffers_len: usize = @min(remaining, remaining_buffers.len); - @memset(remaining_buffers[0..buffers_len], bytes); - remaining -= buffers_len; - buffer_i += buffers_len; - - const written = try bw.unbuffered_writer.writev(buffers[0..buffer_i]); - if (written < end) { - @branchHint(.unlikely); - const remainder = buffer[written..end]; - std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); - bw.end = remainder.len; - return end - start_end; - } - bw.end = 0; - return written - start_end; + return passthru_splat(bw, &.{}, bytes, n); } /// Asserts the `buffer` was initialized with a capacity of at least `@sizeOf(T)` bytes. diff --git a/lib/std/io/CountingWriter.zig b/lib/std/io/CountingWriter.zig index cc4e2ee00e..52a8f4f207 100644 --- a/lib/std/io/CountingWriter.zig +++ b/lib/std/io/CountingWriter.zig @@ -14,6 +14,7 @@ pub fn writer(cw: *CountingWriter) Writer { .context = cw, .vtable = &.{ .writev = passthru_writev, + .splat = passthru_splat, .writeFile = passthru_writeFile, }, }; @@ -28,9 +29,16 @@ pub fn unbufferedWriter(cw: *CountingWriter) std.io.BufferedWriter { fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize { const cw: *CountingWriter = @alignCast(@ptrCast(context)); - const n = try cw.child_writer.writev(data); - cw.bytes_written += n; - return n; + const written = try cw.child_writer.writev(data); + cw.bytes_written += written; + return written; +} + +fn passthru_splat(context: *anyopaque, header: []const u8, pattern: []const u8, n: usize) anyerror!usize { + const cw: *CountingWriter = @alignCast(@ptrCast(context)); + const written = try cw.child_writer.splat(header, pattern, n); + cw.bytes_written += written; + return written; } fn passthru_writeFile( @@ -42,9 +50,9 @@ fn passthru_writeFile( headers_len: usize, ) anyerror!usize { const cw: *CountingWriter = @alignCast(@ptrCast(context)); - const n = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len); - cw.bytes_written += n; - return n; + const written = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len); + cw.bytes_written += written; + return written; } test CountingWriter { diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index d0b0b28f82..cd05eb5672 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -15,6 +15,19 @@ pub const VTable = struct { /// of stream via an error. writev: *const fn (context: *anyopaque, data: []const []const u8) anyerror!usize, + /// `headers_and_pattern` must have length of at least one. The last slice + /// is `pattern` which is the byte sequence to repeat `n` times. The rest + /// of the slices are headers to write before the pattern. + /// + /// When `n == 1`, this is equivalent to `writev`. + /// + /// Number of bytes actually written is returned. + /// + /// Number of bytes returned may be zero, which does not mean + /// end-of-stream. A subsequent call may return nonzero, or may signal end + /// of stream via an error. + splat: *const fn (context: *anyopaque, headers_and_pattern: []const []const u8, n: usize) anyerror!usize, + /// Writes contents from an open file. `headers` are written first, then `len` /// bytes of `file` starting from `offset`, then `trailers`. ///