From 7b5886118dabb59967d3e9b17d0502146df2ef92 Mon Sep 17 00:00:00 2001
From: Jacob Young
Date: Wed, 2 Apr 2025 18:03:53 -0400
Subject: [PATCH] Io.Condition: implement full API

---
 lib/std/Io.zig           | 25 ++++++++++----
 lib/std/Io/EventLoop.zig | 70 +++++++++++++++++++++++++++++++---------
 lib/std/Thread/Pool.zig  |  7 ++--
 3 files changed, 78 insertions(+), 24 deletions(-)

diff --git a/lib/std/Io.zig b/lib/std/Io.zig
index d1e3947609..9a3423825b 100644
--- a/lib/std/Io.zig
+++ b/lib/std/Io.zig
@@ -983,7 +983,7 @@ pub const VTable = struct {
     mutexUnlock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) void,
 
     conditionWait: *const fn (?*anyopaque, cond: *Condition, mutex: *Mutex) Cancelable!void,
-    conditionWake: *const fn (?*anyopaque, cond: *Condition) void,
+    conditionWake: *const fn (?*anyopaque, cond: *Condition, wake: Condition.Wake) void,
 
     createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File,
     openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File,
@@ -1162,9 +1162,20 @@ pub const Condition = struct {
         return io.vtable.conditionWait(io.userdata, cond, mutex);
     }
 
-    pub fn wake(cond: *Condition, io: Io) void {
-        io.vtable.conditionWake(io.userdata, cond);
+    pub fn signal(cond: *Condition, io: Io) void {
+        io.vtable.conditionWake(io.userdata, cond, .one);
     }
+
+    pub fn broadcast(cond: *Condition, io: Io) void {
+        io.vtable.conditionWake(io.userdata, cond, .all);
+    }
+
+    pub const Wake = enum {
+        /// wake up only one thread
+        one,
+        /// wake up all threads
+        all,
+    };
 };
 
 pub const TypeErasedQueue = struct {
@@ -1216,7 +1227,7 @@ pub const TypeErasedQueue = struct {
             remaining = remaining[copy_len..];
             getter.data.remaining = getter.data.remaining[copy_len..];
             if (getter.data.remaining.len == 0) {
-                getter.data.condition.wake(io);
+                getter.data.condition.signal(io);
                 continue;
             }
             q.getters.prepend(getter);
@@ -1299,7 +1310,7 @@ pub const TypeErasedQueue = struct {
             putter.data.remaining = putter.data.remaining[copy_len..];
             remaining = remaining[copy_len..];
             if (putter.data.remaining.len == 0) {
-                putter.data.condition.wake(io);
+                putter.data.condition.signal(io);
             } else {
                 assert(remaining.len == 0);
                 q.putters.prepend(putter);
@@ -1332,7 +1343,7 @@ pub const TypeErasedQueue = struct {
             putter.data.remaining = putter.data.remaining[copy_len..];
             q.put_index += copy_len;
             if (putter.data.remaining.len == 0) {
-                putter.data.condition.wake(io);
+                putter.data.condition.signal(io);
                 continue;
             }
             const second_available = q.buffer[0..q.get_index];
@@ -1341,7 +1352,7 @@ pub const TypeErasedQueue = struct {
             putter.data.remaining = putter.data.remaining[copy_len..];
             q.put_index = copy_len;
             if (putter.data.remaining.len == 0) {
-                putter.data.condition.wake(io);
+                putter.data.condition.signal(io);
                 continue;
             }
             q.putters.prepend(putter);
diff --git a/lib/std/Io/EventLoop.zig b/lib/std/Io/EventLoop.zig
index 83747bd008..edd00baac6 100644
--- a/lib/std/Io/EventLoop.zig
+++ b/lib/std/Io/EventLoop.zig
@@ -555,8 +555,24 @@ const SwitchMessage = struct {
            .condition_wait => |condition_wait| {
                const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
                assert(prev_fiber.queue_next == null);
-                const cond_state: *?*Fiber = @ptrCast(&condition_wait.cond.state);
-                assert(@atomicRmw(?*Fiber, cond_state, .Xchg, prev_fiber, .release) == null); // More than one wait on same Condition is illegal.
+                const cond_impl = prev_fiber.resultPointer(ConditionImpl);
+                cond_impl.* = .{
+                    .tail = prev_fiber,
+                    .event = .queued,
+                };
+                if (@cmpxchgStrong(
+                    ?*Fiber,
+                    @as(*?*Fiber, @ptrCast(&condition_wait.cond.state)),
+                    null,
+                    prev_fiber,
+                    .release,
+                    .acquire,
+                )) |waiting_fiber| {
+                    const waiting_cond_impl = waiting_fiber.?.resultPointer(ConditionImpl);
+                    assert(waiting_cond_impl.tail.queue_next == null);
+                    waiting_cond_impl.tail.queue_next = prev_fiber;
+                    waiting_cond_impl.tail = prev_fiber;
+                }
                condition_wait.mutex.unlock(el.io());
            },
            .exit => for (el.threads.allocated[0..@atomicLoad(u32, &el.threads.active, .acquire)]) |*each_thread| {
@@ -1267,10 +1283,7 @@ fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadl
 fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    el.yield(null, .{ .mutex_lock = .{
-        .prev_state = prev_state,
-        .mutex = mutex,
-    } });
+    el.yield(null, .{ .mutex_lock = .{ .prev_state = prev_state, .mutex = mutex } });
 }
 
 fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
     var maybe_waiting_fiber: ?*Fiber = @ptrFromInt(@intFromEnum(prev_state));
@@ -1294,21 +1307,48 @@ fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mut
     el.yield(maybe_waiting_fiber.?, .reschedule);
 }
 
+const ConditionImpl = struct {
+    tail: *Fiber,
+    event: union(enum) {
+        queued,
+        wake: Io.Condition.Wake,
+    },
+};
+
 fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) Io.Cancelable!void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    el.yield(null, .{ .condition_wait = .{
-        .cond = cond,
-        .mutex = mutex,
-    } });
+    el.yield(null, .{ .condition_wait = .{ .cond = cond, .mutex = mutex } });
+    const thread = Thread.current();
+    const fiber = thread.currentFiber();
+    const cond_impl = fiber.resultPointer(ConditionImpl);
     try mutex.lock(el.io());
+    switch (cond_impl.event) {
+        .queued => {},
+        .wake => |wake| if (fiber.queue_next) |next_fiber| switch (wake) {
+            .one => if (@cmpxchgStrong(
+                ?*Fiber,
+                @as(*?*Fiber, @ptrCast(&cond.state)),
+                null,
+                next_fiber,
+                .release,
+                .acquire,
+            )) |old_fiber| {
+                const old_cond_impl = old_fiber.?.resultPointer(ConditionImpl);
+                assert(old_cond_impl.tail.queue_next == null);
+                old_cond_impl.tail.queue_next = next_fiber;
+                old_cond_impl.tail = cond_impl.tail;
+            },
+            .all => el.schedule(thread, .{ .head = next_fiber, .tail = cond_impl.tail }),
+        },
+    }
+    fiber.queue_next = null;
 }
 
-fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    const cond_state: *?*Fiber = @ptrCast(&cond.state);
-    if (@atomicRmw(?*Fiber, cond_state, .Xchg, null, .acquire)) |fiber| {
-        el.yield(fiber, .reschedule);
-    }
+    const waiting_fiber = @atomicRmw(?*Fiber, @as(*?*Fiber, @ptrCast(&cond.state)), .Xchg, null, .acquire) orelse return;
+    waiting_fiber.resultPointer(ConditionImpl).event = .{ .wake = wake };
+    el.yield(waiting_fiber, .reschedule);
 }
 
 fn errno(signed: i32) std.os.linux.E {
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig
index 05bc4801b6..1e9903e45d 100644
--- a/lib/std/Thread/Pool.zig
+++ b/lib/std/Thread/Pool.zig
@@ -666,7 +666,7 @@ fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) I
     }
 }
 
-fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     _ = pool;
     comptime assert(@TypeOf(cond.state) == u64);
@@ -690,7 +690,10 @@ fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
         return;
     }
 
-    const to_wake = 1;
+    const to_wake = switch (wake) {
+        .one => 1,
+        .all => wakeable,
+    };
 
     // Reserve the amount of waiters to wake by incrementing the signals count.
     // Release barrier ensures code before the wake() happens before the signal it posted and consumed by the wait() threads.
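
Usage note (illustrative, not part of the patch): the sketch below shows how the renamed condition-variable API introduced above is meant to be consumed, assuming the usual pattern of a boolean predicate protected by an Io.Mutex. The helper names waitUntilReady and markReady and the bool predicate are hypothetical; Io.Mutex.lock/unlock, Io.Condition.wait/signal/broadcast, Io.Cancelable, and the mapping of signal to Wake.one and broadcast to Wake.all are taken from the diff.

const std = @import("std");
const Io = std.Io;

// Consumer side: block until the predicate is set. Per the EventLoop changes
// above, wait() releases the mutex while the task is suspended and reacquires
// it before returning, so the predicate is always re-checked under the lock.
fn waitUntilReady(io: Io, mutex: *Io.Mutex, cond: *Io.Condition, ready: *const bool) Io.Cancelable!void {
    try mutex.lock(io);
    defer mutex.unlock(io);
    while (!ready.*) try cond.wait(io, mutex);
}

// Producer side: publish the predicate under the lock, then wake waiters.
// broadcast() wakes every waiter (.all); signal() would wake at most one (.one).
fn markReady(io: Io, mutex: *Io.Mutex, cond: *Io.Condition, ready: *bool) Io.Cancelable!void {
    try mutex.lock(io);
    ready.* = true;
    mutex.unlock(io);
    cond.broadcast(io);
}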