From 9d163c7ac3fbc6adfa7c6fa299647d0731f96167 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 3 Jun 2025 11:32:16 -0700 Subject: [PATCH] std.net: update to new IO API --- lib/std/fs/File.zig | 5 + lib/std/io/Reader.zig | 6 +- lib/std/io/Writer.zig | 27 ++-- lib/std/net.zig | 345 ++++++++++++++++++++++++------------------ 4 files changed, 219 insertions(+), 164 deletions(-) diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index f941e37120..45eab505cb 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1213,6 +1213,10 @@ pub const Writer = struct { const max_buffers_len = 16; pub fn init(file: File, buffer: []u8) std.io.Writer { + return initMode(file, buffer, .positional); + } + + pub fn initMode(file: File, buffer: []u8, init_mode: Writer.Mode) std.io.Writer { return .{ .file = file, .interface = .{ @@ -1223,6 +1227,7 @@ pub const Writer = struct { }, .buffer = buffer, }, + .mode = init_mode, }; } diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index 2f13b467bb..3af3feaad3 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -1392,19 +1392,19 @@ pub fn Hashed(comptime Hasher: type) type { const this: *@This() = @alignCast(@fieldParentPtr("interface", r)); const data = w.writableVector(limit); const n = try this.in.readVec(data); - w.advanceVector(n); + const result = w.advanceVector(n); var remaining: usize = n; for (data) |slice| { if (remaining < slice.len) { this.hasher.update(slice[0..remaining]); - return n; + return result; } else { remaining -= slice.len; this.hasher.update(slice); } } assert(remaining == 0); - return n; + return result; } fn discard(r: *Reader, limit: Limit) Error!usize { diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index 19346f58c0..d2d890802a 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -1672,16 +1672,7 @@ pub fn discardingDrain(w: *Writer, data: []const []const u8, splat: usize) Error pub fn discardingSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) FileError!usize { if (File.Handle == void) return error.Unimplemented; - if (w.end != 0) { - if (@intFromEnum(limit) >= w.end) { - w.end = 0; - } else { - const remaining = w.buffer[@intFromEnum(limit)..w.end]; - @memmove(w.buffer[0..remaining.len], remaining); - w.end = remaining.len; - } - return 0; - } + w.end = 0; if (file_reader.getSize()) |size| { const n = limit.minInt(size - file_reader.pos); file_reader.seekBy(@intCast(n)) catch return error.Unimplemented; @@ -1694,6 +1685,19 @@ pub fn discardingSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) F } } +/// This function is used by `VTable.drain` function implementations to +/// implement partial drains. +pub fn consume(w: *Writer, n: usize) usize { + if (n < w.end) { + const remaining = w.buffer[n..w.end]; + @memmove(w.buffer[0..remaining.len], remaining); + w.end = remaining.len; + return 0; + } + defer w.end = 0; + return n - w.end; +} + /// For use when the `Writer` implementation can cannot offer a more efficient /// implementation than a basic read/write loop on the file. pub fn unimplementedSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) FileError!usize { @@ -1768,13 +1772,14 @@ pub fn Hashed(comptime Hasher: type) type { fn drain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { const this: *@This() = @alignCast(@fieldParentPtr("interface", w)); const aux_n = try this.out.writeSplatAux(w.buffered(), data, splat); - if (aux_n <= w.end) { + if (aux_n < w.end) { this.hasher.update(w.buffer[0..aux_n]); const remaining = w.buffer[aux_n..w.end]; @memmove(w.buffer[0..remaining.len], remaining); w.end = remaining.len; return 0; } + this.hasher.update(w.buffered()); const n = aux_n - w.end; w.end = 0; var remaining: usize = n; diff --git a/lib/std/net.zig b/lib/std/net.zig index d13d69ef87..ebfc998ca2 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -13,6 +13,7 @@ const native_os = builtin.os.tag; const windows = std.os.windows; const Allocator = std.mem.Allocator; const ArrayList = std.ArrayListUnmanaged; +const File = std.fs.File; // Windows 10 added support for unix sockets in build 17063, redstone 4 is the // first release to support them. @@ -853,7 +854,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { // TODO: Instead of having a massive error set, make the error set have categories, and then // store the sub-error as a diagnostic value. -const GetAddressListError = Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ +const GetAddressListError = Allocator.Error || File.OpenError || File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ TemporaryNameServerFailure, NameServerFailure, AddressFamilyNotSupported, @@ -1363,9 +1364,8 @@ fn linuxLookupNameFromHosts( defer file.close(); var line_buf: [512]u8 = undefined; - var file_reader = file.reader(); - var br = file_reader.interface().buffered(&line_buf); - return parseHosts(gpa, addrs, canon, name, family, port, &br) catch |err| switch (err) { + var file_reader = file.reader(&line_buf); + return parseHosts(gpa, addrs, canon, name, family, port, &file_reader.interface) catch |err| switch (err) { error.OutOfMemory => return error.OutOfMemory, error.ReadFailed => return file_reader.err.?, }; @@ -1378,7 +1378,7 @@ fn parseHosts( name: []const u8, family: posix.sa_family_t, port: u16, - br: *std.io.Reader, + br: *io.Reader, ) error{ OutOfMemory, ReadFailed }!void { while (true) { const line = br.takeDelimiterExclusive('\n') catch |err| switch (err) { @@ -1584,15 +1584,14 @@ const ResolvConf = struct { defer file.close(); var line_buf: [512]u8 = undefined; - var file_reader = file.reader(); - var br = file_reader.interface().buffered(&line_buf); - return parse(rc, &br) catch |err| switch (err) { + var file_reader = file.reader(&line_buf); + return parse(rc, &file_reader.interface) catch |err| switch (err) { error.ReadFailed => return file_reader.err.?, else => |e| return e, }; } - fn parse(rc: *ResolvConf, br: *std.io.Reader) !void { + fn parse(rc: *ResolvConf, br: *io.Reader) !void { const gpa = rc.gpa; while (br.takeSentinel('\n')) |line_with_comment| { const line = line: { @@ -1893,7 +1892,10 @@ pub const Stream = struct { pub const Reader = switch (native_os) { .windows => struct { - stream: Stream, + /// Use `interface` to access portably. + interface_state: io.Reader, + /// Use `getStream` to access portably. + net_stream: Stream, err: ?Error = null, pub const Error = ReadError; @@ -1902,29 +1904,27 @@ pub const Stream = struct { return r.stream; } - pub fn interface(r: *Reader) std.io.Reader { + pub fn interface(r: *Reader) *io.Reader { + return &r.interface_state; + } + + pub fn init(net_stream: Stream, buffer: []u8) Reader { return .{ - .context = r.stream.handle, - .vtable = &.{ - .read = read, - .readVec = readVec, - .discard = discard, + .interface_state = .{ + .context = undefined, + .vtable = &.{ + .stream = stream, + .discard = discard, + }, + .buffer = buffer, }, + .net_stream = net_stream, }; } - fn read( - context: ?*anyopaque, - bw: *std.io.Writer, - limit: std.io.Limit, - ) std.io.Reader.Error!usize { - const buf = limit.slice(try bw.writableSliceGreedy(1)); - const n = try readVec(context, &.{buf}); - bw.advance(n); - return n; - } - - fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { + fn stream(io_r: *io.Reader, io_w: *io.Writer, limit: io.Limit) io.Reader.StreamError!usize { + const r: *Reader = @fieldParentPtr("interface", io_r); + const data = io_w.writableVector(limit); var iovecs: [max_buffers_len]windows.WSABUF = undefined; var iovecs_i: usize = 0; for (data) |d| { @@ -1939,7 +1939,7 @@ pub const Stream = struct { if (bufs.len == 0) return .{}; // Prevent false positive end detection on empty `data`. var n: u32 = undefined; var flags: u32 = 0; - const rc = windows.ws2_32.WSARecvFrom(context, bufs.ptr, bufs.len, &n, &flags, null, null, null, null); + const rc = windows.ws2_32.WSARecvFrom(r.net_stream.handle, bufs.ptr, bufs.len, &n, &flags, null, null, null, null); if (rc != 0) switch (windows.ws2_32.WSAGetLastError()) { .WSAECONNRESET => return error.ConnectionResetByPeer, .WSAEFAULT => unreachable, // a pointer is not completely contained in user address space. @@ -1955,22 +1955,35 @@ pub const Stream = struct { .WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O else => |err| return windows.unexpectedWSAError(err), }; - return .{ .len = n, .end = n == 0 }; + if (n == 0) return error.EndOfStream; + return io_w.advanceVector(n); } - fn discard(context: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize { - _ = context; + fn discard(io_r: *io.Reader, limit: io.Limit) io.Reader.Error!usize { + const r: *Reader = @fieldParentPtr("interface", io_r); + _ = r; _ = limit; @panic("TODO"); } }, else => struct { - file_reader: std.fs.File.Reader, + file_reader: File.Reader, pub const Error = ReadError; - pub fn interface(r: *Reader) std.io.Reader { - return r.file_reader.interface(); + pub fn interface(r: *Reader) *io.Reader { + return &r.file_reader.interface; + } + + pub fn init(net_stream: Stream, buffer: []u8) Reader { + return .{ + .file_reader = .{ + .interface = File.Reader.initInterface(buffer), + .file = .{ .handle = net_stream.handle }, + .mode = .streaming, + .seek_err = error.Unseekable, + }, + }; } pub fn getStream(r: *const Reader) Stream { @@ -1981,16 +1994,22 @@ pub const Stream = struct { pub const Writer = switch (native_os) { .windows => struct { + /// This field is present on all systems. + interface: io.Writer, + /// Use `getStream` for cross-platform support. stream: Stream, pub const Error = WriteError; - pub fn interface(w: *Writer) std.io.Writer { + pub fn init(stream: Stream, buffer: []u8) Writer { return .{ - .context = w.stream.handle, - .vtable = &.{ - .writeSplat = writeSplat, - .writeFile = writeFile, + .stream = stream, + .interface = .{ + .context = undefined, + .vtable = &.{ + .drain = drain, + }, + .buffer = buffer, }, }; } @@ -1999,41 +2018,63 @@ pub const Stream = struct { return w.stream; } - fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize { + fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize { + const w: *Writer = @fieldParentPtr("interface", io_w); + const buffered = io_w.buffered(); comptime assert(native_os == .windows); - if (data.len == 1 and splat == 0) return 0; - var splat_buffer: [256]u8 = undefined; + var splat_buffer: [splat_buffer_len]u8 = undefined; var iovecs: [max_buffers_len]windows.WSABUF = undefined; - var len: u32 = @min(iovecs.len, data.len); - for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ - .buf = if (d.len == 0) "" else d.ptr, // TODO: does Windows allow ptr=undefined len=0 ? - .len = d.len, - }; + var len: u32 = 0; + if (buffered.len != 0) { + iovecs[len] = .{ + .base = buffered.ptr, + .len = buffered.len, + }; + len += 1; + } + for (data[0..data.len]) |bytes| { + if (bytes.len == 0) continue; + iovecs[len] = .{ + .buf = bytes.ptr, + .len = bytes.len, + }; + len += 1; + } + const pattern = data[data.len - 1]; switch (splat) { 0 => len -= 1, 1 => {}, - else => { - const pattern = data[data.len - 1]; - if (pattern.len == 1) { + else => switch (pattern.len) { + 0 => {}, + 1 => { + // Replace the 1-byte buffer with a bigger one. const memset_len = @min(splat_buffer.len, splat); const buf = splat_buffer[0..memset_len]; @memset(buf, pattern[0]); - iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; + iovecs[len - 1] = .{ .buf = buf.ptr, .len = buf.len }; var remaining_splat = splat - buf.len; while (remaining_splat > splat_buffer.len and len < iovecs.len) { - iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; + iovecs[len] = .{ .buf = &splat_buffer, .len = splat_buffer.len }; remaining_splat -= splat_buffer.len; len += 1; } if (remaining_splat > 0 and len < iovecs.len) { - iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; + iovecs[len] = .{ .buf = &splat_buffer, .len = remaining_splat }; len += 1; } - } + }, + else => for (0..splat - 1) |_| { + if (iovecs.len - len == 0) break; + iovecs[len] = .{ + .buf = pattern.ptr, + .len = pattern.len, + }; + len += 1; + }, }, } var n: u32 = undefined; - const rc = windows.ws2_32.WSASend(context, &iovecs, len, &n, 0, null, null); + const rc = windows.ws2_32.WSASend(w.stream.handle, &iovecs, len, &n, 0, null, null); if (rc == windows.ws2_32.SOCKET_ERROR) switch (windows.ws2_32.WSAGetLastError()) { .WSAECONNABORTED => return error.ConnectionResetByPeer, .WSAECONNRESET => return error.ConnectionResetByPeer, @@ -2054,123 +2095,127 @@ pub const Stream = struct { .WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O else => |err| return windows.unexpectedWSAError(err), }; - return n; - } - - fn writeFile( - context: *anyopaque, - in_file: std.fs.File, - in_offset: u64, - in_len: std.io.Writer.FileLen, - headers_and_trailers: []const []const u8, - headers_len: usize, - ) std.io.Writer.FileError!usize { - const len_int = switch (in_len) { - .zero => return writeSplat(context, headers_and_trailers, 1), - .entire_file => std.math.maxInt(usize), - else => in_len.int(), - }; - if (headers_len > 0) return writeSplat(context, headers_and_trailers[0..headers_len], 1); - var file_contents_buffer: [4096]u8 = undefined; - const read_buffer = file_contents_buffer[0..@min(file_contents_buffer.len, len_int)]; - const n = try windows.ReadFile(in_file.handle, read_buffer, in_offset); - return writeSplat(context, &.{read_buffer[0..n]}, 1); + return io_w.consume(n); } }, else => struct { - file_writer: std.fs.File.Writer, + /// This field is present on all systems. + interface: io.Writer, + err: ?Error = null, + file_writer: File.Writer, pub const Error = WriteError; - pub fn interface(w: *Writer) std.io.Writer { + pub fn init(stream: Stream, buffer: []u8) Writer { return .{ - .context = &w.file_writer, - .vtable = &.{ - .writeSplat = writeSplat, - .writeFile = std.fs.File.Writer.writeFile, + .interface = .{ + .context = undefined, + .vtable = &.{ + .drain = drain, + .sendFile = sendFile, + }, + .buffer = buffer, }, - }; - } - - fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize { - const fw: *std.fs.File.Writer = @alignCast(@ptrCast(context)); - const w: *Writer = @fieldParentPtr("file_writer", fw); - var splat_buffer: [256]u8 = undefined; - var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; - var len: usize = @min(iovecs.len, data.len); - for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ - .base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length. - .len = d.len, - }; - var msg: posix.msghdr_const = .{ - .name = null, - .namelen = 0, - .iov = &iovecs, - .iovlen = len, - .control = null, - .controllen = 0, - .flags = 0, - }; - switch (splat) { - 0 => msg.iovlen = len - 1, - 1 => {}, - else => { - const pattern = data[data.len - 1]; - if (pattern.len == 1) { - const memset_len = @min(splat_buffer.len, splat); - const buf = splat_buffer[0..memset_len]; - @memset(buf, pattern[0]); - iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; - var remaining_splat = splat - buf.len; - while (remaining_splat > splat_buffer.len and len < iovecs.len) { - iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; - remaining_splat -= splat_buffer.len; - len += 1; - } - if (remaining_splat > 0 and len < iovecs.len) { - iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; - len += 1; - } - msg.iovlen = len; - } - }, - } - const flags = posix.MSG.NOSIGNAL; - return std.posix.sendmsg(fw.file.handle, &msg, flags) catch |err| { - w.err = err; - return error.WriteFailed; + .file_writer = .initMode(stream.handle, &.{}, .streaming), }; } pub fn getStream(w: *const Writer) Stream { return .{ .handle = w.file_writer.file.handle }; } + + fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize { + const w: *Writer = @fieldParentPtr("interface", io_w); + const buffered = io_w.buffered(); + var splat_buffer: [splat_buffer_len]u8 = undefined; + var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; + var msg: posix.msghdr_const = msg: { + var i: usize = 0; + if (buffered.len != 0) { + iovecs[i] = .{ + .base = buffered.ptr, + .len = buffered.len, + }; + i += 1; + } + for (data[0..data.len]) |bytes| { + // OS checks ptr addr before length so zero length vectors must be omitted. + if (bytes.len == 0) continue; + iovecs[i] = .{ + .base = bytes.ptr, + .len = bytes.len, + }; + i += 1; + if (iovecs.len - i == 0) break; + } + break :msg .{ + .name = null, + .namelen = 0, + .iov = &iovecs, + .iovlen = i, + .control = null, + .controllen = 0, + .flags = 0, + }; + }; + const pattern = data[data.len - 1]; + switch (splat) { + 0 => msg.iovlen -= 1, + 1 => {}, + else => switch (pattern.len) { + 0 => {}, + 1 => { + // Replace the 1-byte buffer with a bigger one. + const memset_len = @min(splat_buffer.len, splat); + const buf = splat_buffer[0..memset_len]; + @memset(buf, pattern[0]); + iovecs[msg.iovlen - 1] = .{ .base = buf.ptr, .len = buf.len }; + var remaining_splat = splat - buf.len; + while (remaining_splat > splat_buffer.len and msg.iovlen < iovecs.len) { + iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = splat_buffer.len }; + remaining_splat -= splat_buffer.len; + msg.iovlen += 1; + } + if (remaining_splat > 0 and msg.iovlen < iovecs.len) { + iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = remaining_splat }; + msg.iovlen += 1; + } + }, + else => for (0..splat - 1) |_| { + if (iovecs.len - msg.iovlen == 0) break; + iovecs[msg.iovlen] = .{ + .base = pattern.ptr, + .len = pattern.len, + }; + msg.iovlen += 1; + }, + }, + } + const flags = posix.MSG.NOSIGNAL; + return io_w.consume(std.posix.sendmsg(w.file_writer.file.handle, &msg, flags) catch |err| { + w.err = err; + return error.WriteFailed; + }); + } + + fn sendFile(io_w: *io.Writer, file_reader: *File.Reader, limit: io.Limit) io.Writer.FileError!usize { + const w: *Writer = @fieldParentPtr("interface", io_w); + return io_w.sendFileTo(&w.file_writer.interface, file_reader, limit); + } }, }; - pub fn reader(stream: Stream) Reader { - return switch (native_os) { - .windows => .{ .stream = stream }, - else => .{ .file_reader = .{ - .file = .{ .handle = stream.handle }, - .mode = .streaming, - .seek_err = error.Unseekable, - } }, - }; + pub fn reader(stream: Stream, buffer: []u8) Reader { + return .init(stream, buffer); } - pub fn writer(stream: Stream) Writer { - return switch (native_os) { - .windows => .{ .stream = stream }, - else => .{ .file_writer = .{ - .file = .{ .handle = stream.handle }, - .mode = .streaming, - } }, - }; + pub fn writer(stream: Stream, buffer: []u8) Writer { + return .init(stream, buffer); } const max_buffers_len = 8; + const splat_buffer_len = 256; }; pub const Server = struct {