std.net: update to new IO API

This commit is contained in:
Andrew Kelley 2025-06-03 11:32:16 -07:00
parent 1e2aab2f97
commit 9d163c7ac3
4 changed files with 219 additions and 164 deletions

View File

@ -1213,6 +1213,10 @@ pub const Writer = struct {
const max_buffers_len = 16;
pub fn init(file: File, buffer: []u8) std.io.Writer {
return initMode(file, buffer, .positional);
}
pub fn initMode(file: File, buffer: []u8, init_mode: Writer.Mode) std.io.Writer {
return .{
.file = file,
.interface = .{
@ -1223,6 +1227,7 @@ pub const Writer = struct {
},
.buffer = buffer,
},
.mode = init_mode,
};
}

View File

@ -1392,19 +1392,19 @@ pub fn Hashed(comptime Hasher: type) type {
const this: *@This() = @alignCast(@fieldParentPtr("interface", r));
const data = w.writableVector(limit);
const n = try this.in.readVec(data);
w.advanceVector(n);
const result = w.advanceVector(n);
var remaining: usize = n;
for (data) |slice| {
if (remaining < slice.len) {
this.hasher.update(slice[0..remaining]);
return n;
return result;
} else {
remaining -= slice.len;
this.hasher.update(slice);
}
}
assert(remaining == 0);
return n;
return result;
}
fn discard(r: *Reader, limit: Limit) Error!usize {

View File

@ -1672,16 +1672,7 @@ pub fn discardingDrain(w: *Writer, data: []const []const u8, splat: usize) Error
pub fn discardingSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) FileError!usize {
if (File.Handle == void) return error.Unimplemented;
if (w.end != 0) {
if (@intFromEnum(limit) >= w.end) {
w.end = 0;
} else {
const remaining = w.buffer[@intFromEnum(limit)..w.end];
@memmove(w.buffer[0..remaining.len], remaining);
w.end = remaining.len;
}
return 0;
}
w.end = 0;
if (file_reader.getSize()) |size| {
const n = limit.minInt(size - file_reader.pos);
file_reader.seekBy(@intCast(n)) catch return error.Unimplemented;
@ -1694,6 +1685,19 @@ pub fn discardingSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) F
}
}
/// This function is used by `VTable.drain` function implementations to
/// implement partial drains.
pub fn consume(w: *Writer, n: usize) usize {
if (n < w.end) {
const remaining = w.buffer[n..w.end];
@memmove(w.buffer[0..remaining.len], remaining);
w.end = remaining.len;
return 0;
}
defer w.end = 0;
return n - w.end;
}
/// For use when the `Writer` implementation can cannot offer a more efficient
/// implementation than a basic read/write loop on the file.
pub fn unimplementedSendFile(w: *Writer, file_reader: *File.Reader, limit: Limit) FileError!usize {
@ -1768,13 +1772,14 @@ pub fn Hashed(comptime Hasher: type) type {
fn drain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
const this: *@This() = @alignCast(@fieldParentPtr("interface", w));
const aux_n = try this.out.writeSplatAux(w.buffered(), data, splat);
if (aux_n <= w.end) {
if (aux_n < w.end) {
this.hasher.update(w.buffer[0..aux_n]);
const remaining = w.buffer[aux_n..w.end];
@memmove(w.buffer[0..remaining.len], remaining);
w.end = remaining.len;
return 0;
}
this.hasher.update(w.buffered());
const n = aux_n - w.end;
w.end = 0;
var remaining: usize = n;

View File

@ -13,6 +13,7 @@ const native_os = builtin.os.tag;
const windows = std.os.windows;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayListUnmanaged;
const File = std.fs.File;
// Windows 10 added support for unix sockets in build 17063, redstone 4 is the
// first release to support them.
@ -853,7 +854,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
// TODO: Instead of having a massive error set, make the error set have categories, and then
// store the sub-error as a diagnostic value.
const GetAddressListError = Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{
const GetAddressListError = Allocator.Error || File.OpenError || File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{
TemporaryNameServerFailure,
NameServerFailure,
AddressFamilyNotSupported,
@ -1363,9 +1364,8 @@ fn linuxLookupNameFromHosts(
defer file.close();
var line_buf: [512]u8 = undefined;
var file_reader = file.reader();
var br = file_reader.interface().buffered(&line_buf);
return parseHosts(gpa, addrs, canon, name, family, port, &br) catch |err| switch (err) {
var file_reader = file.reader(&line_buf);
return parseHosts(gpa, addrs, canon, name, family, port, &file_reader.interface) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.ReadFailed => return file_reader.err.?,
};
@ -1378,7 +1378,7 @@ fn parseHosts(
name: []const u8,
family: posix.sa_family_t,
port: u16,
br: *std.io.Reader,
br: *io.Reader,
) error{ OutOfMemory, ReadFailed }!void {
while (true) {
const line = br.takeDelimiterExclusive('\n') catch |err| switch (err) {
@ -1584,15 +1584,14 @@ const ResolvConf = struct {
defer file.close();
var line_buf: [512]u8 = undefined;
var file_reader = file.reader();
var br = file_reader.interface().buffered(&line_buf);
return parse(rc, &br) catch |err| switch (err) {
var file_reader = file.reader(&line_buf);
return parse(rc, &file_reader.interface) catch |err| switch (err) {
error.ReadFailed => return file_reader.err.?,
else => |e| return e,
};
}
fn parse(rc: *ResolvConf, br: *std.io.Reader) !void {
fn parse(rc: *ResolvConf, br: *io.Reader) !void {
const gpa = rc.gpa;
while (br.takeSentinel('\n')) |line_with_comment| {
const line = line: {
@ -1893,7 +1892,10 @@ pub const Stream = struct {
pub const Reader = switch (native_os) {
.windows => struct {
stream: Stream,
/// Use `interface` to access portably.
interface_state: io.Reader,
/// Use `getStream` to access portably.
net_stream: Stream,
err: ?Error = null,
pub const Error = ReadError;
@ -1902,29 +1904,27 @@ pub const Stream = struct {
return r.stream;
}
pub fn interface(r: *Reader) std.io.Reader {
pub fn interface(r: *Reader) *io.Reader {
return &r.interface_state;
}
pub fn init(net_stream: Stream, buffer: []u8) Reader {
return .{
.context = r.stream.handle,
.vtable = &.{
.read = read,
.readVec = readVec,
.discard = discard,
.interface_state = .{
.context = undefined,
.vtable = &.{
.stream = stream,
.discard = discard,
},
.buffer = buffer,
},
.net_stream = net_stream,
};
}
fn read(
context: ?*anyopaque,
bw: *std.io.Writer,
limit: std.io.Limit,
) std.io.Reader.Error!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const n = try readVec(context, &.{buf});
bw.advance(n);
return n;
}
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
fn stream(io_r: *io.Reader, io_w: *io.Writer, limit: io.Limit) io.Reader.StreamError!usize {
const r: *Reader = @fieldParentPtr("interface", io_r);
const data = io_w.writableVector(limit);
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var iovecs_i: usize = 0;
for (data) |d| {
@ -1939,7 +1939,7 @@ pub const Stream = struct {
if (bufs.len == 0) return .{}; // Prevent false positive end detection on empty `data`.
var n: u32 = undefined;
var flags: u32 = 0;
const rc = windows.ws2_32.WSARecvFrom(context, bufs.ptr, bufs.len, &n, &flags, null, null, null, null);
const rc = windows.ws2_32.WSARecvFrom(r.net_stream.handle, bufs.ptr, bufs.len, &n, &flags, null, null, null, null);
if (rc != 0) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEFAULT => unreachable, // a pointer is not completely contained in user address space.
@ -1955,22 +1955,35 @@ pub const Stream = struct {
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return .{ .len = n, .end = n == 0 };
if (n == 0) return error.EndOfStream;
return io_w.advanceVector(n);
}
fn discard(context: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
_ = context;
fn discard(io_r: *io.Reader, limit: io.Limit) io.Reader.Error!usize {
const r: *Reader = @fieldParentPtr("interface", io_r);
_ = r;
_ = limit;
@panic("TODO");
}
},
else => struct {
file_reader: std.fs.File.Reader,
file_reader: File.Reader,
pub const Error = ReadError;
pub fn interface(r: *Reader) std.io.Reader {
return r.file_reader.interface();
pub fn interface(r: *Reader) *io.Reader {
return &r.file_reader.interface;
}
pub fn init(net_stream: Stream, buffer: []u8) Reader {
return .{
.file_reader = .{
.interface = File.Reader.initInterface(buffer),
.file = .{ .handle = net_stream.handle },
.mode = .streaming,
.seek_err = error.Unseekable,
},
};
}
pub fn getStream(r: *const Reader) Stream {
@ -1981,16 +1994,22 @@ pub const Stream = struct {
pub const Writer = switch (native_os) {
.windows => struct {
/// This field is present on all systems.
interface: io.Writer,
/// Use `getStream` for cross-platform support.
stream: Stream,
pub const Error = WriteError;
pub fn interface(w: *Writer) std.io.Writer {
pub fn init(stream: Stream, buffer: []u8) Writer {
return .{
.context = w.stream.handle,
.vtable = &.{
.writeSplat = writeSplat,
.writeFile = writeFile,
.stream = stream,
.interface = .{
.context = undefined,
.vtable = &.{
.drain = drain,
},
.buffer = buffer,
},
};
}
@ -1999,41 +2018,63 @@ pub const Stream = struct {
return w.stream;
}
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize {
const w: *Writer = @fieldParentPtr("interface", io_w);
const buffered = io_w.buffered();
comptime assert(native_os == .windows);
if (data.len == 1 and splat == 0) return 0;
var splat_buffer: [256]u8 = undefined;
var splat_buffer: [splat_buffer_len]u8 = undefined;
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var len: u32 = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.buf = if (d.len == 0) "" else d.ptr, // TODO: does Windows allow ptr=undefined len=0 ?
.len = d.len,
};
var len: u32 = 0;
if (buffered.len != 0) {
iovecs[len] = .{
.base = buffered.ptr,
.len = buffered.len,
};
len += 1;
}
for (data[0..data.len]) |bytes| {
if (bytes.len == 0) continue;
iovecs[len] = .{
.buf = bytes.ptr,
.len = bytes.len,
};
len += 1;
}
const pattern = data[data.len - 1];
switch (splat) {
0 => len -= 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
else => switch (pattern.len) {
0 => {},
1 => {
// Replace the 1-byte buffer with a bigger one.
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
iovecs[len - 1] = .{ .buf = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
iovecs[len] = .{ .buf = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
iovecs[len] = .{ .buf = &splat_buffer, .len = remaining_splat };
len += 1;
}
}
},
else => for (0..splat - 1) |_| {
if (iovecs.len - len == 0) break;
iovecs[len] = .{
.buf = pattern.ptr,
.len = pattern.len,
};
len += 1;
},
},
}
var n: u32 = undefined;
const rc = windows.ws2_32.WSASend(context, &iovecs, len, &n, 0, null, null);
const rc = windows.ws2_32.WSASend(w.stream.handle, &iovecs, len, &n, 0, null, null);
if (rc == windows.ws2_32.SOCKET_ERROR) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => return error.ConnectionResetByPeer,
.WSAECONNRESET => return error.ConnectionResetByPeer,
@ -2054,123 +2095,127 @@ pub const Stream = struct {
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return n;
}
fn writeFile(
context: *anyopaque,
in_file: std.fs.File,
in_offset: u64,
in_len: std.io.Writer.FileLen,
headers_and_trailers: []const []const u8,
headers_len: usize,
) std.io.Writer.FileError!usize {
const len_int = switch (in_len) {
.zero => return writeSplat(context, headers_and_trailers, 1),
.entire_file => std.math.maxInt(usize),
else => in_len.int(),
};
if (headers_len > 0) return writeSplat(context, headers_and_trailers[0..headers_len], 1);
var file_contents_buffer: [4096]u8 = undefined;
const read_buffer = file_contents_buffer[0..@min(file_contents_buffer.len, len_int)];
const n = try windows.ReadFile(in_file.handle, read_buffer, in_offset);
return writeSplat(context, &.{read_buffer[0..n]}, 1);
return io_w.consume(n);
}
},
else => struct {
file_writer: std.fs.File.Writer,
/// This field is present on all systems.
interface: io.Writer,
err: ?Error = null,
file_writer: File.Writer,
pub const Error = WriteError;
pub fn interface(w: *Writer) std.io.Writer {
pub fn init(stream: Stream, buffer: []u8) Writer {
return .{
.context = &w.file_writer,
.vtable = &.{
.writeSplat = writeSplat,
.writeFile = std.fs.File.Writer.writeFile,
.interface = .{
.context = undefined,
.vtable = &.{
.drain = drain,
.sendFile = sendFile,
},
.buffer = buffer,
},
};
}
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
const fw: *std.fs.File.Writer = @alignCast(@ptrCast(context));
const w: *Writer = @fieldParentPtr("file_writer", fw);
var splat_buffer: [256]u8 = undefined;
var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
var len: usize = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
var msg: posix.msghdr_const = .{
.name = null,
.namelen = 0,
.iov = &iovecs,
.iovlen = len,
.control = null,
.controllen = 0,
.flags = 0,
};
switch (splat) {
0 => msg.iovlen = len - 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
len += 1;
}
msg.iovlen = len;
}
},
}
const flags = posix.MSG.NOSIGNAL;
return std.posix.sendmsg(fw.file.handle, &msg, flags) catch |err| {
w.err = err;
return error.WriteFailed;
.file_writer = .initMode(stream.handle, &.{}, .streaming),
};
}
pub fn getStream(w: *const Writer) Stream {
return .{ .handle = w.file_writer.file.handle };
}
fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize {
const w: *Writer = @fieldParentPtr("interface", io_w);
const buffered = io_w.buffered();
var splat_buffer: [splat_buffer_len]u8 = undefined;
var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
var msg: posix.msghdr_const = msg: {
var i: usize = 0;
if (buffered.len != 0) {
iovecs[i] = .{
.base = buffered.ptr,
.len = buffered.len,
};
i += 1;
}
for (data[0..data.len]) |bytes| {
// OS checks ptr addr before length so zero length vectors must be omitted.
if (bytes.len == 0) continue;
iovecs[i] = .{
.base = bytes.ptr,
.len = bytes.len,
};
i += 1;
if (iovecs.len - i == 0) break;
}
break :msg .{
.name = null,
.namelen = 0,
.iov = &iovecs,
.iovlen = i,
.control = null,
.controllen = 0,
.flags = 0,
};
};
const pattern = data[data.len - 1];
switch (splat) {
0 => msg.iovlen -= 1,
1 => {},
else => switch (pattern.len) {
0 => {},
1 => {
// Replace the 1-byte buffer with a bigger one.
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[msg.iovlen - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and msg.iovlen < iovecs.len) {
iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
msg.iovlen += 1;
}
if (remaining_splat > 0 and msg.iovlen < iovecs.len) {
iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = remaining_splat };
msg.iovlen += 1;
}
},
else => for (0..splat - 1) |_| {
if (iovecs.len - msg.iovlen == 0) break;
iovecs[msg.iovlen] = .{
.base = pattern.ptr,
.len = pattern.len,
};
msg.iovlen += 1;
},
},
}
const flags = posix.MSG.NOSIGNAL;
return io_w.consume(std.posix.sendmsg(w.file_writer.file.handle, &msg, flags) catch |err| {
w.err = err;
return error.WriteFailed;
});
}
fn sendFile(io_w: *io.Writer, file_reader: *File.Reader, limit: io.Limit) io.Writer.FileError!usize {
const w: *Writer = @fieldParentPtr("interface", io_w);
return io_w.sendFileTo(&w.file_writer.interface, file_reader, limit);
}
},
};
pub fn reader(stream: Stream) Reader {
return switch (native_os) {
.windows => .{ .stream = stream },
else => .{ .file_reader = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
.seek_err = error.Unseekable,
} },
};
pub fn reader(stream: Stream, buffer: []u8) Reader {
return .init(stream, buffer);
}
pub fn writer(stream: Stream) Writer {
return switch (native_os) {
.windows => .{ .stream = stream },
else => .{ .file_writer = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
} },
};
pub fn writer(stream: Stream, buffer: []u8) Writer {
return .init(stream, buffer);
}
const max_buffers_len = 8;
const splat_buffer_len = 256;
};
pub const Server = struct {