diff --git a/lib/std/os/linux/io_uring.zig b/lib/std/os/linux/io_uring.zig index cde97e2091..d90561e6e4 100644 --- a/lib/std/os/linux/io_uring.zig +++ b/lib/std/os/linux/io_uring.zig @@ -505,6 +505,24 @@ pub const IO_Uring = struct { return sqe; } + /// Queues (but does not submit) an SQE to perform an multishot `accept4(2)` on a socket. + /// Multishot variant allows an application to issue a single accept request, + /// which will repeatedly trigger a CQE when a connection request comes in. + /// Returns a pointer to the SQE. + pub fn accept_multishot( + self: *IO_Uring, + user_data: u64, + fd: os.fd_t, + addr: ?*os.sockaddr, + addrlen: ?*os.socklen_t, + flags: u32, + ) !*linux.io_uring_sqe { + const sqe = try self.get_sqe(); + io_uring_prep_multishot_accept(sqe, fd, addr, addrlen, flags); + sqe.user_data = user_data; + return sqe; + } + /// Queue (but does not submit) an SQE to perform a `connect(2)` on a socket. /// Returns a pointer to the SQE. pub fn connect( @@ -1621,6 +1639,17 @@ pub fn io_uring_prep_remove_buffers( sqe.buf_index = @intCast(group_id); } +pub fn io_uring_prep_multishot_accept( + sqe: *linux.io_uring_sqe, + fd: os.fd_t, + addr: ?*os.sockaddr, + addrlen: ?*os.socklen_t, + flags: u32, +) void { + io_uring_prep_accept(sqe, fd, addr, addrlen, flags); + sqe.ioprio |= linux.IORING_ACCEPT_MULTISHOT; +} + test "structs/offsets/entries" { if (builtin.os.tag != .linux) return error.SkipZigTest; @@ -3353,20 +3382,10 @@ const SocketTestHarness = struct { fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { // Create a TCP server socket - var address = try net.Address.parseIp4("127.0.0.1", 0); - const kernel_backlog = 1; - const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); + const listener_socket = try createListenerSocket(&address); errdefer os.closeSocket(listener_socket); - try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); - try os.bind(listener_socket, &address.any, address.getOsSockLen()); - try os.listen(listener_socket, kernel_backlog); - - // set address to the OS-chosen IP/port. - var slen: os.socklen_t = address.getOsSockLen(); - try os.getsockname(listener_socket, &address.any, &slen); - // Submit 1 accept var accept_addr: os.sockaddr = undefined; var accept_addr_len: os.socklen_t = @sizeOf(@TypeOf(accept_addr)); @@ -3410,3 +3429,58 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { .client = client, }; } + +fn createListenerSocket(address: *net.Address) !os.socket_t { + const kernel_backlog = 1; + const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); + errdefer os.closeSocket(listener_socket); + + try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); + try os.bind(listener_socket, &address.any, address.getOsSockLen()); + try os.listen(listener_socket, kernel_backlog); + + // set address to the OS-chosen IP/port. + var slen: os.socklen_t = address.getOsSockLen(); + try os.getsockname(listener_socket, &address.any, &slen); + + return listener_socket; +} + +test "accept multishot" { + if (builtin.os.tag != .linux) return error.SkipZigTest; + + var ring = IO_Uring.init(16, 0) catch |err| switch (err) { + error.SystemOutdated => return error.SkipZigTest, + error.PermissionDenied => return error.SkipZigTest, + else => return err, + }; + defer ring.deinit(); + + var address = try net.Address.parseIp4("127.0.0.1", 0); + const listener_socket = try createListenerSocket(&address); + defer os.closeSocket(listener_socket); + + // submit multishot accept operation + var addr: os.sockaddr = undefined; + var addr_len: os.socklen_t = @sizeOf(@TypeOf(addr)); + const userdata: u64 = 0xaaaaaaaa; + _ = try ring.accept_multishot(userdata, listener_socket, &addr, &addr_len, 0); + try testing.expectEqual(@as(u32, 1), try ring.submit()); + + var nr: usize = 4; // number of clients to connect + while (nr > 0) : (nr -= 1) { + // connect client + var client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); + errdefer os.closeSocket(client); + try os.connect(client, &address.any, address.getOsSockLen()); + + // test accept completion + var cqe = try ring.copy_cqe(); + if (cqe.err() == .INVAL) return error.SkipZigTest; + try testing.expect(cqe.res > 0); + try testing.expect(cqe.user_data == userdata); + try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set + + os.closeSocket(client); + } +}