Merge pull request #7124 from LemonBoy/netstuff1

std: Decouple network streams from fs.File
This commit is contained in:
Andrew Kelley 2021-01-11 15:35:00 -08:00 committed by GitHub
commit ec9158305d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 11 deletions

View File

@ -10,8 +10,13 @@ const net = @This();
const mem = std.mem;
const os = std.os;
const fs = std.fs;
const io = std.io;
pub const has_unix_sockets = @hasDecl(os, "sockaddr_un");
// Windows 10 added support for unix sockets in build 17063, redstone 4 is the
// first release to support them.
pub const has_unix_sockets = @hasDecl(os, "sockaddr_un") and
(builtin.os.tag != .windows or
std.Target.current.os.version_range.windows.isAtLeast(.win10_rs4) orelse false);
pub const Address = extern union {
any: os.sockaddr,
@ -596,7 +601,7 @@ pub const Ip6Address = extern struct {
}
};
pub fn connectUnixSocket(path: []const u8) !fs.File {
pub fn connectUnixSocket(path: []const u8) !Stream {
const opt_non_block = if (std.io.is_async) os.SOCK_NONBLOCK else 0;
const sockfd = try os.socket(
os.AF_UNIX,
@ -614,7 +619,7 @@ pub fn connectUnixSocket(path: []const u8) !fs.File {
try os.connect(sockfd, &addr.any, addr.getOsSockLen());
}
return fs.File{
return Stream{
.handle = sockfd,
};
}
@ -648,7 +653,7 @@ pub const AddressList = struct {
};
/// All memory allocated with `allocator` will be freed before this function returns.
pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16) !fs.File {
pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16) !Stream {
const list = try getAddressList(allocator, name, port);
defer list.deinit();
@ -665,7 +670,7 @@ pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16)
return std.os.ConnectError.ConnectionRefused;
}
pub fn tcpConnectToAddress(address: Address) !fs.File {
pub fn tcpConnectToAddress(address: Address) !Stream {
const nonblock = if (std.io.is_async) os.SOCK_NONBLOCK else 0;
const sock_flags = os.SOCK_STREAM | nonblock |
(if (builtin.os.tag == .windows) 0 else os.SOCK_CLOEXEC);
@ -679,7 +684,7 @@ pub fn tcpConnectToAddress(address: Address) !fs.File {
try os.connect(sockfd, &address.any, address.getOsSockLen());
}
return fs.File{ .handle = sockfd };
return Stream{ .handle = sockfd };
}
/// Call `AddressList.deinit` on the result.
@ -1580,6 +1585,55 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8)
}
}
pub const Stream = struct {
// Underlying socket descriptor.
// Note that on some platforms this may not be interchangeable with a
// regular files descriptor.
handle: os.socket_t,
pub fn close(self: Stream) void {
os.closeSocket(self.handle);
}
pub const ReadError = os.ReadError;
pub const WriteError = os.WriteError;
pub const Reader = io.Reader(Stream, ReadError, read);
pub const Writer = io.Writer(Stream, WriteError, write);
pub fn reader(self: Stream) Reader {
return .{ .context = self };
}
pub fn writer(self: Stream) Writer {
return .{ .context = self };
}
pub fn read(self: Stream, buffer: []u8) ReadError!usize {
if (std.Target.current.os.tag == .windows) {
return os.windows.ReadFile(self.handle, buffer, null, io.default_mode);
}
if (std.io.is_async) {
return std.event.Loop.instance.?.read(self.handle, buffer, false);
} else {
return os.read(self.handle, buffer);
}
}
pub fn write(self: Stream, buffer: []const u8) WriteError!usize {
if (std.Target.current.os.tag == .windows) {
return os.windows.WriteFile(self.handle, buffer, null, io.default_mode);
}
if (std.io.is_async) {
return std.event.Loop.instance.?.write(self.handle, buffer, false);
} else {
return os.write(self.handle, buffer);
}
}
};
pub const StreamServer = struct {
/// Copied from `Options` on `init`.
kernel_backlog: u31,
@ -1686,7 +1740,7 @@ pub const StreamServer = struct {
} || os.UnexpectedError;
pub const Connection = struct {
file: fs.File,
stream: Stream,
address: Address,
};
@ -1705,7 +1759,7 @@ pub const StreamServer = struct {
if (accept_result) |fd| {
return Connection{
.file = fs.File{ .handle = fd },
.stream = Stream{ .handle = fd },
.address = accepted_addr,
};
} else |err| switch (err) {

View File

@ -145,7 +145,7 @@ test "listen on a port, send bytes, receive bytes" {
// Try only the IPv4 variant as some CI builders have no IPv6 localhost
// configured.
const localhost = try net.Address.parseIp("127.0.0.1", 8080);
const localhost = try net.Address.parseIp("127.0.0.1", 0);
var server = net.StreamServer.init(.{});
defer server.deinit();
@ -165,8 +165,9 @@ test "listen on a port, send bytes, receive bytes" {
defer t.wait();
var client = try server.accept();
defer client.stream.close();
var buf: [16]u8 = undefined;
const n = try client.file.reader().read(&buf);
const n = try client.stream.reader().read(&buf);
testing.expectEqual(@as(usize, 12), n);
testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
@ -249,6 +250,49 @@ fn testServer(server: *net.StreamServer) anyerror!void {
var client = try server.accept();
const stream = client.file.writer();
const stream = client.stream.writer();
try stream.print("hello from server\n", .{});
}
test "listen on a unix socket, send bytes, receive bytes" {
if (builtin.single_threaded) return error.SkipZigTest;
if (!net.has_unix_sockets) return error.SkipZigTest;
if (std.builtin.os.tag == .windows) {
_ = try std.os.windows.WSAStartup(2, 2);
}
defer {
if (std.builtin.os.tag == .windows) {
std.os.windows.WSACleanup() catch unreachable;
}
}
var server = net.StreamServer.init(.{});
defer server.deinit();
const socket_path = "socket.unix";
var socket_addr = try net.Address.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {};
try server.listen(socket_addr);
const S = struct {
fn clientFn(_: void) !void {
const socket = try net.connectUnixSocket(socket_path);
defer socket.close();
_ = try socket.writer().writeAll("Hello world!");
}
};
const t = try std.Thread.spawn({}, S.clientFn);
defer t.wait();
var client = try server.accept();
defer client.stream.close();
var buf: [16]u8 = undefined;
const n = try client.stream.reader().read(&buf);
testing.expectEqual(@as(usize, 12), n);
testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
}