From 27b34a5b77623b106ae1e1c7c4a466608a10905e Mon Sep 17 00:00:00 2001 From: Jan Philipp Hafer Date: Tue, 31 Oct 2023 21:54:03 +0100 Subject: [PATCH] std.net: enable forcing non-blocking mode for accept Justification: It is common for non-CPU bound short routines to do non-blocking accept to eliminate unnecessary delays before subscribing to data, for example in hardware integration tests. --- lib/std/net.zig | 21 +++++++++++++++------ lib/std/net/test.zig | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/lib/std/net.zig b/lib/std/net.zig index 7e15effe54..a687e39142 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1871,6 +1871,7 @@ pub const StreamServer = struct { kernel_backlog: u31, reuse_address: bool, reuse_port: bool, + force_nonblocking: bool, /// `undefined` until `listen` returns successfully. listen_address: Address, @@ -1888,6 +1889,9 @@ pub const StreamServer = struct { /// Enable SO.REUSEPORT on the socket. reuse_port: bool = false, + + /// Force non-blocking mode. + force_nonblocking: bool = false, }; /// After this call succeeds, resources have been acquired and must @@ -1898,6 +1902,7 @@ pub const StreamServer = struct { .kernel_backlog = options.kernel_backlog, .reuse_address = options.reuse_address, .reuse_port = options.reuse_port, + .force_nonblocking = options.force_nonblocking, .listen_address = undefined, }; } @@ -1911,9 +1916,11 @@ pub const StreamServer = struct { pub fn listen(self: *StreamServer, address: Address) !void { const nonblock = if (std.io.is_async) os.SOCK.NONBLOCK else 0; const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock; + var use_sock_flags: u32 = sock_flags; + if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK; const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP; - const sockfd = try os.socket(address.any.family, sock_flags, proto); + const sockfd = try os.socket(address.any.family, use_sock_flags, proto); self.sockfd = sockfd; errdefer { os.closeSocket(sockfd); @@ -1963,8 +1970,8 @@ pub const StreamServer = struct { /// The system-wide limit on the total number of open files has been reached. SystemFdQuotaExceeded, - /// Not enough free memory. This often means that the memory allocation is limited - /// by the socket buffer limits, not by the system memory. + /// Not enough free memory. This often means that the memory allocation + /// is limited by the socket buffer limits, not by the system memory. SystemResources, /// Socket is not listening for new connections. @@ -1972,6 +1979,9 @@ pub const StreamServer = struct { ProtocolFailure, + /// Socket is in non-blocking mode and there is no connection to accept. + WouldBlock, + /// Firewall rules forbid connection. BlockedByFirewall, @@ -2007,9 +2017,8 @@ pub const StreamServer = struct { .stream = Stream{ .handle = fd }, .address = accepted_addr, }; - } else |err| switch (err) { - error.WouldBlock => unreachable, - else => |e| return e, + } else |err| { + return err; } } }; diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index 5e98ee2a4d..caac8f4e6c 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -340,3 +340,27 @@ fn generateFileName(base_name: []const u8) ![]const u8 { _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes); return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name }); } + +test "non-blocking tcp server" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + const localhost = try net.Address.parseIp("127.0.0.1", 0); + var server = net.StreamServer.init(.{ .force_nonblocking = true }); + defer server.deinit(); + try server.listen(localhost); + + const accept_err = server.accept(); + try testing.expectError(error.WouldBlock, accept_err); + + const socket_file = try net.tcpConnectToAddress(server.listen_address); + defer socket_file.close(); + + var client = try server.accept(); + const stream = client.stream.writer(); + try stream.print("hello from server\n", .{}); + + var buf: [100]u8 = undefined; + const len = try socket_file.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +}