From 0c1e102e975f062fdb7b513256598a0f9846a543 Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Thu, 3 Apr 2025 21:15:19 -0700
Subject: [PATCH] std.Io.EventLoop: implement select

---
 lib/std/Io/EventLoop.zig | 92 +++++++++++++++++++++++++++++++---------
 lib/std/Thread/Pool.zig  |  2 +-
 2 files changed, 72 insertions(+), 22 deletions(-)

diff --git a/lib/std/Io/EventLoop.zig b/lib/std/Io/EventLoop.zig
index 3fc5aa0bac..b821f2f7e4 100644
--- a/lib/std/Io/EventLoop.zig
+++ b/lib/std/Io/EventLoop.zig
@@ -485,6 +485,7 @@ const SwitchMessage = struct {
         reschedule,
         recycle,
         register_awaiter: *?*Fiber,
+        register_select: []const *Io.AnyFuture,
         mutex_lock: struct {
             prev_state: Io.Mutex.State,
             mutex: *Io.Mutex,
@@ -514,13 +515,21 @@ const SwitchMessage = struct {
             .register_awaiter => |awaiter| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
                 assert(prev_fiber.queue_next == null);
-                if (@atomicRmw(
-                    ?*Fiber,
-                    awaiter,
-                    .Xchg,
-                    prev_fiber,
-                    .acq_rel,
-                ) == Fiber.finished) el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+                if (@atomicRmw(?*Fiber, awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished)
+                    el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+            },
+            .register_select => |futures| {
+                const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
+                assert(prev_fiber.queue_next == null);
+                for (futures) |any_future| {
+                    const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+                    if (@atomicRmw(?*Fiber, &future_fiber.awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished) {
+                        const closure: *AsyncClosure = .fromFiber(future_fiber);
+                        if (!@atomicRmw(bool, &closure.already_awaited, .Xchg, true, .seq_cst)) {
+                            el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+                        }
+                    }
+                }
             },
             .mutex_lock => |mutex_lock| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
@@ -661,6 +670,7 @@ const AsyncClosure = struct {
     fiber: *Fiber,
     start: *const fn (context: *const anyopaque, result: *anyopaque) void,
     result_align: Alignment,
+    already_awaited: bool,
 
     fn contextPointer(closure: *AsyncClosure) [*]align(Fiber.max_context_align.toByteUnits()) u8 {
         return @alignCast(@as([*]u8, @ptrCast(closure)) + @sizeOf(AsyncClosure));
@@ -668,12 +678,24 @@ const AsyncClosure = struct {
 
     fn call(closure: *AsyncClosure, message: *const SwitchMessage) callconv(.withStackAlign(.c, @alignOf(AsyncClosure))) noreturn {
         message.handle(closure.event_loop);
-        std.log.debug("{*} performing async", .{closure.fiber});
-        closure.start(closure.contextPointer(), closure.fiber.resultBytes(closure.result_align));
-        const awaiter = @atomicRmw(?*Fiber, &closure.fiber.awaiter, .Xchg, Fiber.finished, .acq_rel);
-        closure.event_loop.yield(awaiter, .nothing);
+        const fiber = closure.fiber;
+        std.log.debug("{*} performing async", .{fiber});
+        closure.start(closure.contextPointer(), fiber.resultBytes(closure.result_align));
+        const awaiter = @atomicRmw(?*Fiber, &fiber.awaiter, .Xchg, Fiber.finished, .acq_rel);
+        const ready_awaiter = r: {
+            const a = awaiter orelse break :r null;
+            if (@atomicRmw(bool, &closure.already_awaited, .Xchg, true, .acq_rel)) break :r null;
+            break :r a;
+        };
+        closure.event_loop.yield(ready_awaiter, .nothing);
         unreachable; // switched to dead fiber
     }
+
+    fn fromFiber(fiber: *Fiber) *AsyncClosure {
+        return @ptrFromInt(Fiber.max_context_align.max(.of(AsyncClosure)).backward(
+            @intFromPtr(fiber.allocatedEnd()) - Fiber.max_context_size,
+        ) - @sizeOf(AsyncClosure));
+    }
 };
 
 fn @"async"(
@@ -696,9 +718,7 @@ fn @"async"(
     };
     std.log.debug("allocated {*}", .{fiber});
 
-    const closure: *AsyncClosure = @ptrFromInt(Fiber.max_context_align.max(.of(AsyncClosure)).backward(
-        @intFromPtr(fiber.allocatedEnd()) - Fiber.max_context_size,
-    ) - @sizeOf(AsyncClosure));
+    const closure: *AsyncClosure = .fromFiber(fiber);
     const stack_end: [*]usize = @alignCast(@ptrCast(closure));
     (stack_end - 1)[0..1].* = .{@intFromPtr(&AsyncClosure.call)};
     fiber.* = .{
@@ -721,6 +741,7 @@ fn @"async"(
         .fiber = fiber,
         .start = start,
         .result_align = result_alignment,
+        .already_awaited = false,
     };
     @memcpy(closure.contextPointer(), context);
 
@@ -728,13 +749,6 @@ fn @"async"(
     return @ptrCast(fiber);
 }
 
-fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
-    const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    _ = el;
-    _ = futures;
-    @panic("TODO");
-}
-
 const DetachedClosure = struct {
     event_loop: *EventLoop,
     fiber: *Fiber,
@@ -836,6 +850,42 @@ fn @"await"(
     event_loop.recycle(future_fiber);
 }
 
+fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
+    const el: *EventLoop = @alignCast(@ptrCast(userdata));
+
+    // Optimization to avoid the yield below.
+    for (futures, 0..) |any_future, i| {
+        const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+        if (@atomicLoad(?*Fiber, &future_fiber.awaiter, .acquire) == Fiber.finished)
+            return i;
+    }
+
+    el.yield(null, .{ .register_select = futures });
+
+    std.log.debug("back from select yield", .{});
+
+    const my_thread: *Thread = .current();
+    const my_fiber = my_thread.currentFiber();
+    var result: ?usize = null;
+
+    for (futures, 0..) |any_future, i| {
+        const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+        if (@cmpxchgStrong(?*Fiber, &future_fiber.awaiter, my_fiber, null, .seq_cst, .seq_cst)) |awaiter| {
+            if (awaiter == Fiber.finished) {
+                if (result == null) result = i;
+            } else if (awaiter) |a| {
+                const closure: *AsyncClosure = .fromFiber(a);
+                closure.already_awaited = false;
+            }
+        } else {
+            const closure: *AsyncClosure = .fromFiber(my_fiber);
+            closure.already_awaited = false;
+        }
+    }
+
+    return result.?;
+}
+
 fn cancel(
     userdata: ?*anyopaque,
     any_future: *std.Io.AnyFuture,
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig
index 64668a1c1f..02daa473f8 100644
--- a/lib/std/Thread/Pool.zig
+++ b/lib/std/Thread/Pool.zig
@@ -364,7 +364,7 @@ const AsyncClosure = struct {
     context_offset: usize,
     result_offset: usize,
 
-    const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(std.mem.alignBackward(usize, std.math.maxInt(usize), @alignOf(std.Thread.ResetEvent)));
+    const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(@alignOf(std.Thread.ResetEvent));
 
     const canceling_tid: std.Thread.Id = switch (@typeInfo(std.Thread.Id)) {
         .int => |int_info| switch (int_info.signedness) {
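
The core of this patch is the already_awaited claim flag: when a future finishes at the same moment a select is registering interest in it, both the completing fiber (in AsyncClosure.call) and the selecting fiber (in the .register_select handler and in select's cleanup loop) can observe "finished, with an awaiter registered", and exactly one of them must schedule the awaiter. Below is an illustrative model of that handshake using plain atomics, with a usize slot standing in for the real Fiber/AsyncClosure machinery; FutureModel, complete, and register are invented names for this sketch, not identifiers from the patch.

const std = @import("std");

/// Toy stand-in for the patch's Fiber.awaiter slot: 0 = no awaiter yet,
/// `finished` = the future's fiber has completed, anything else = the id
/// of a registered awaiter.
const FutureModel = struct {
    awaiter: usize = 0,
    already_awaited: bool = false,

    const finished: usize = std.math.maxInt(usize);

    /// Completing side (mirrors AsyncClosure.call): publish the finished
    /// state, then return the awaiter to wake only if this side won the claim.
    fn complete(f: *FutureModel) ?usize {
        const prev = @atomicRmw(usize, &f.awaiter, .Xchg, finished, .acq_rel);
        if (prev == 0) return null; // nobody registered yet; they will observe the finished state
        if (@atomicRmw(bool, &f.already_awaited, .Xchg, true, .acq_rel)) return null;
        return prev; // we won the claim: wake this awaiter
    }

    /// Selecting side (mirrors the .register_select message): register `me`
    /// as the awaiter; returns true only if the future is already done and
    /// this side won the claim, i.e. it must reschedule itself.
    fn register(f: *FutureModel, me: usize) bool {
        const prev = @atomicRmw(usize, &f.awaiter, .Xchg, me, .acq_rel);
        if (prev != finished) return false; // not done yet; the completer will wake us
        return !@atomicRmw(bool, &f.already_awaited, .Xchg, true, .acq_rel);
    }
};

test "the wakeup is claimed exactly once" {
    // Awaiter registers first, then the future completes: the completer
    // observes the awaiter and wins the claim.
    var f: FutureModel = .{};
    try std.testing.expect(!f.register(42));
    try std.testing.expectEqual(@as(?usize, 42), f.complete());

    // Future completes first, then the awaiter registers: the completer
    // has nobody to wake, and the selector claims its own wakeup.
    f = .{};
    try std.testing.expectEqual(@as(?usize, null), f.complete());
    try std.testing.expect(f.register(42));
}

Whichever side flips already_awaited from false to true owns the wakeup; the other side backs off, so the selecting fiber is scheduled at most once per completed future, and select's second loop can safely unregister itself from the futures that did not complete.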