std.net.Stream: fix interface() infinite recursion

This commit is contained in:
Andrew Kelley 2025-04-20 14:20:25 -07:00
parent e9fd9798f4
commit d7b081882a

View File

@ -1829,276 +1829,277 @@ pub const Stream = struct {
}
}
pub const Reader = struct {
impl: switch (native_os) {
.windows => Stream,
else => struct {
fr: std.fs.File.Reader,
err: Error!void,
},
},
const ReadError = posix.ReadError;
pub const Error = posix.ReadError;
pub fn interface(r: *Reader) std.io.Reader {
return switch (native_os) {
.windows => .{
.context = r.impl.stream.handle,
.vtable = &.{
.read = windows_read,
.readVec = windows_readVec,
.discard = windows_discard,
},
},
else => r.interface(),
};
}
fn windows_read(
context: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
) std.io.Reader.Error!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const status = try windows_readVec(context, &.{buf});
bw.advance(status.len);
return status;
}
fn windows_readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var iovecs_i: usize = 0;
for (data) |d| {
// In case Windows checks pointer address before length, we must omit
// length-zero vectors.
if (d.len == 0) continue;
iovecs[iovecs_i] = .{ .buf = d.ptr, .len = d.len };
iovecs_i += 1;
if (iovecs_i >= iovecs.len) break;
}
const bufs = iovecs[0..iovecs_i];
if (bufs.len == 0) return .{}; // Prevent false positive end detection on empty `data`.
var n: u32 = undefined;
var flags: u32 = 0;
const rc = windows.ws2_32.WSARecvFrom(context, bufs.ptr, bufs.len, &n, &flags, null, null, null, null);
if (rc != 0) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEFAULT => unreachable, // a pointer is not completely contained in user address space.
.WSAEINPROGRESS, .WSAEINTR => unreachable, // deprecated and removed in WSA 2.2
.WSAEINVAL => return error.SocketNotBound,
.WSAEMSGSIZE => return error.MessageTooBig,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAENETRESET => return error.ConnectionResetByPeer,
.WSAENOTCONN => return error.SocketNotConnected,
.WSAEWOULDBLOCK => return error.WouldBlock,
.WSANOTINITIALISED => unreachable, // WSAStartup must be called before this function
.WSA_IO_PENDING => unreachable, // not using overlapped I/O
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return .{ .len = n, .end = n == 0 };
}
fn windows_discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
_ = context;
_ = limit;
@panic("TODO");
}
const WriteError = posix.SendMsgError || error{
ConnectionResetByPeer,
SocketNotBound,
MessageTooBig,
NetworkSubsystemFailed,
SystemResources,
SocketNotConnected,
Unexpected,
};
pub const Writer = struct {
impl: switch (native_os) {
.windows => Stream,
else => PosixImpl,
},
pub const Reader = switch (native_os) {
.windows => struct {
stream: Stream,
err: ?Error = null,
const PosixImpl = struct {
fw: std.fs.File.Writer,
err: Error!void,
};
pub const Error = ReadError;
pub const Error = posix.SendMsgError || error{
ConnectionResetByPeer,
SocketNotBound,
MessageTooBig,
NetworkSubsystemFailed,
SystemResources,
SocketNotConnected,
Unexpected,
};
pub fn interface(w: *Writer) std.io.Writer {
return switch (native_os) {
.windows => .{
.context = w.impl.stream.handle,
pub fn interface(r: *Reader) std.io.Reader {
return .{
.context = r.stream.handle,
.vtable = &.{
.writeSplat = windows_writeSplat,
.writeFile = windows_writeFile,
.read = read,
.readVec = readVec,
.discard = discard,
},
},
else => .{
.context = &w.impl,
};
}
fn read(
context: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
) std.io.Reader.Error!usize {
const buf = limit.slice(try bw.writableSliceGreedy(1));
const status = try readVec(context, &.{buf});
bw.advance(status.len);
return status;
}
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var iovecs_i: usize = 0;
for (data) |d| {
// In case Windows checks pointer address before length, we must omit
// length-zero vectors.
if (d.len == 0) continue;
iovecs[iovecs_i] = .{ .buf = d.ptr, .len = d.len };
iovecs_i += 1;
if (iovecs_i >= iovecs.len) break;
}
const bufs = iovecs[0..iovecs_i];
if (bufs.len == 0) return .{}; // Prevent false positive end detection on empty `data`.
var n: u32 = undefined;
var flags: u32 = 0;
const rc = windows.ws2_32.WSARecvFrom(context, bufs.ptr, bufs.len, &n, &flags, null, null, null, null);
if (rc != 0) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEFAULT => unreachable, // a pointer is not completely contained in user address space.
.WSAEINPROGRESS, .WSAEINTR => unreachable, // deprecated and removed in WSA 2.2
.WSAEINVAL => return error.SocketNotBound,
.WSAEMSGSIZE => return error.MessageTooBig,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAENETRESET => return error.ConnectionResetByPeer,
.WSAENOTCONN => return error.SocketNotConnected,
.WSAEWOULDBLOCK => return error.WouldBlock,
.WSANOTINITIALISED => unreachable, // WSAStartup must be called before this function
.WSA_IO_PENDING => unreachable, // not using overlapped I/O
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return .{ .len = n, .end = n == 0 };
}
fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
_ = context;
_ = limit;
@panic("TODO");
}
},
else => struct {
file_reader: std.fs.File.Reader,
pub const Error = ReadError;
pub fn interface(r: *Reader) std.io.Reader {
return r.file_reader.interface();
}
},
};
pub const Writer = switch (native_os) {
.windows => struct {
stream: Stream,
pub const Error = WriteError;
pub fn interface(w: *Writer) std.io.Writer {
return .{
.context = w.stream.handle,
.vtable = &.{
.writeSplat = posix_writeSplat,
.writeSplat = writeSplat,
.writeFile = writeFile,
},
};
}
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
comptime assert(native_os == .windows);
if (data.len == 1 and splat == 0) return 0;
var splat_buffer: [256]u8 = undefined;
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var len: u32 = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.buf = if (d.len == 0) "" else d.ptr, // TODO: does Windows allow ptr=undefined len=0 ?
.len = d.len,
};
switch (splat) {
0 => len -= 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
len += 1;
}
}
},
}
var n: u32 = undefined;
const rc = windows.ws2_32.WSASend(context, &iovecs, len, &n, 0, null, null);
if (rc == windows.ws2_32.SOCKET_ERROR) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => return error.ConnectionResetByPeer,
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEFAULT => unreachable, // a pointer is not completely contained in user address space.
.WSAEINPROGRESS, .WSAEINTR => unreachable, // deprecated and removed in WSA 2.2
.WSAEINVAL => return error.SocketNotBound,
.WSAEMSGSIZE => return error.MessageTooBig,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAENETRESET => return error.ConnectionResetByPeer,
.WSAENOBUFS => return error.SystemResources,
.WSAENOTCONN => return error.SocketNotConnected,
.WSAENOTSOCK => unreachable, // not a socket
.WSAEOPNOTSUPP => unreachable, // only for message-oriented sockets
.WSAESHUTDOWN => unreachable, // cannot send on a socket after write shutdown
.WSAEWOULDBLOCK => return error.WouldBlock,
.WSANOTINITIALISED => unreachable, // WSAStartup must be called before this function
.WSA_IO_PENDING => unreachable, // not using overlapped I/O
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return n;
}
fn writeFile(
context: *anyopaque,
in_file: std.fs.File,
in_offset: u64,
in_len: std.io.Writer.FileLen,
headers_and_trailers: []const []const u8,
headers_len: usize,
) std.io.Writer.FileError!usize {
const len_int = switch (in_len) {
.zero => return writeSplat(context, headers_and_trailers, 1),
.entire_file => std.math.maxInt(usize),
else => in_len.int(),
};
if (headers_len > 0) return writeSplat(context, headers_and_trailers[0..headers_len], 1);
var file_contents_buffer: [4096]u8 = undefined;
const read_buffer = file_contents_buffer[0..@min(file_contents_buffer.len, len_int)];
const n = try windows.ReadFile(in_file.handle, read_buffer, in_offset);
return writeSplat(context, &.{read_buffer[0..n]}, 1);
}
},
else => struct {
file_writer: std.fs.File.Writer,
err: ?Error = null,
pub const Error = WriteError;
pub fn interface(w: *Writer) std.io.Writer {
return .{
.context = &w.file_writer,
.vtable = &.{
.writeSplat = writeSplat,
.writeFile = std.fs.File.Writer.writeFile,
},
},
};
}
fn windows_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
comptime assert(native_os == .windows);
if (data.len == 1 and splat == 0) return 0;
var splat_buffer: [256]u8 = undefined;
var iovecs: [max_buffers_len]windows.WSABUF = undefined;
var len: u32 = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.buf = if (d.len == 0) "" else d.ptr, // TODO: does Windows allow ptr=undefined len=0 ?
.len = d.len,
};
switch (splat) {
0 => len -= 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
len += 1;
}
}
},
};
}
var n: u32 = undefined;
const rc = windows.ws2_32.WSASend(context, &iovecs, len, &n, 0, null, null);
if (rc == windows.ws2_32.SOCKET_ERROR) switch (windows.ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => return error.ConnectionResetByPeer,
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEFAULT => unreachable, // a pointer is not completely contained in user address space.
.WSAEINPROGRESS, .WSAEINTR => unreachable, // deprecated and removed in WSA 2.2
.WSAEINVAL => return error.SocketNotBound,
.WSAEMSGSIZE => return error.MessageTooBig,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAENETRESET => return error.ConnectionResetByPeer,
.WSAENOBUFS => return error.SystemResources,
.WSAENOTCONN => return error.SocketNotConnected,
.WSAENOTSOCK => unreachable, // not a socket
.WSAEOPNOTSUPP => unreachable, // only for message-oriented sockets
.WSAESHUTDOWN => unreachable, // cannot send on a socket after write shutdown
.WSAEWOULDBLOCK => return error.WouldBlock,
.WSANOTINITIALISED => unreachable, // WSAStartup must be called before this function
.WSA_IO_PENDING => unreachable, // not using overlapped I/O
.WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O
else => |err| return windows.unexpectedWSAError(err),
};
return n;
}
fn posix_writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
const fw: *std.fs.File.Writer = @alignCast(@ptrCast(context));
const impl: *PosixImpl = @fieldParentPtr("fw", fw);
comptime assert(native_os != .windows);
var splat_buffer: [256]u8 = undefined;
var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
var len: usize = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
var msg: posix.msghdr_const = .{
.name = null,
.namelen = 0,
.iov = &iovecs,
.iovlen = len,
.control = null,
.controllen = 0,
.flags = 0,
};
switch (splat) {
0 => msg.iovlen = len - 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
const fw: *std.fs.File.Writer = @alignCast(@ptrCast(context));
const w: *Writer = @fieldParentPtr("file_writer", fw);
var splat_buffer: [256]u8 = undefined;
var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
var len: usize = @min(iovecs.len, data.len);
for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
var msg: posix.msghdr_const = .{
.name = null,
.namelen = 0,
.iov = &iovecs,
.iovlen = len,
.control = null,
.controllen = 0,
.flags = 0,
};
switch (splat) {
0 => msg.iovlen = len - 1,
1 => {},
else => {
const pattern = data[data.len - 1];
if (pattern.len == 1) {
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
len += 1;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
len += 1;
}
msg.iovlen = len;
}
if (remaining_splat > 0 and len < iovecs.len) {
iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
len += 1;
}
msg.iovlen = len;
}
},
},
}
const flags = posix.MSG.NOSIGNAL;
return std.posix.sendmsg(fw.file.handle, &msg, flags) catch |err| {
w.err = err;
return error.WriteFailed;
};
}
const flags = posix.MSG.NOSIGNAL;
return std.posix.sendmsg(fw.file.handle, &msg, flags) catch |err| {
impl.err = err;
return error.WriteFailed;
};
}
fn windows_writeFile(
context: *anyopaque,
in_file: std.fs.File,
in_offset: u64,
in_len: std.io.Writer.FileLen,
headers_and_trailers: []const []const u8,
headers_len: usize,
) std.io.Writer.FileError!usize {
const len_int = switch (in_len) {
.zero => return windows_writeSplat(context, headers_and_trailers, 1),
.entire_file => std.math.maxInt(usize),
else => in_len.int(),
};
if (headers_len > 0) return windows_writeSplat(context, headers_and_trailers[0..headers_len], 1);
var file_contents_buffer: [4096]u8 = undefined;
const read_buffer = file_contents_buffer[0..@min(file_contents_buffer.len, len_int)];
const n = try windows.ReadFile(in_file.handle, read_buffer, in_offset);
return windows_writeSplat(context, &.{read_buffer[0..n]}, 1);
}
},
};
pub fn reader(stream: Stream) Reader {
return switch (native_os) {
.windows => .{ .impl = stream },
else => .{ .impl = .{
.fr = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
.seek_err = error.Unseekable,
},
.err = {},
.windows => .{ .stream = stream },
else => .{ .file_reader = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
.seek_err = error.Unseekable,
} },
};
}
pub fn writer(stream: Stream) Writer {
return switch (native_os) {
.windows => .{ .impl = stream },
else => .{ .impl = .{
.fw = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
},
.err = {},
.windows => .{ .stream = stream },
else => .{ .file_writer = .{
.file = .{ .handle = stream.handle },
.mode = .streaming,
} },
};
}