diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig index 1482c8166d..6829a9f15c 100644 --- a/lib/std/Thread/Condition.zig +++ b/lib/std/Thread/Condition.zig @@ -194,59 +194,50 @@ const FutexImpl = struct { const signal_mask = 0xffff << 16; fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void { - // Register that we're waiting on the state by incrementing the wait count. - // This assumes that there can be at most ((1<<16)-1) or 65,355 threads concurrently waiting on the same Condvar. - // If this is hit in practice, then this condvar not working is the least of your concerns. + // Observe the epoch, then check the state again to see if we should wake up. + // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock: + // + // - T1: s = LOAD(&state) + // - T2: UPDATE(&s, signal) + // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch) + // - T1: e = LOAD(&epoch) (was reordered after the state load) + // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) + // + // Acquire barrier to ensure the epoch load happens before the state load. + const epoch = self.epoch.load(.Acquire); var state = self.state.fetchAdd(one_waiter, .Monotonic); assert(state & waiter_mask != waiter_mask); state += one_waiter; + var futex_deadline = Futex.Deadline.init(timeout); - // Temporarily release the mutex in order to block on the condition variable. mutex.unlock(); defer mutex.lock(); - var futex_deadline = Futex.Deadline.init(timeout); - while (true) { - // Try to wake up by consuming a signal and decremented the waiter we added previously. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; - } - - // Observe the epoch, then check the state again to see if we should wake up. - // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock: - // - // - T1: s = LOAD(&state) - // - T2: UPDATE(&s, signal) - // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch) - // - T1: e = LOAD(&epoch) (was reordered after the state load) - // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) - // - // Acquire barrier to ensure the epoch load happens before the state load. - const epoch = self.epoch.load(.Acquire); - state = self.state.load(.Monotonic); - if (state & signal_mask != 0) { - continue; - } - - futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { - // On timeout, we must decrement the waiter we added above. - error.Timeout => { - while (true) { - // If there's a signal when we're timing out, consume it and report being woken up instead. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; - } - - // Remove the waiter we added and officially return timed out. - const new_state = state - one_waiter; - state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; + futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { + // On timeout, we must decrement the waiter we added above. + error.Timeout => { + while (true) { + // If there's a signal when we're timing out, consume it and report being woken up instead. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } - }, - }; + + // Remove the waiter we added and officially return timed out. + const new_state = state - one_waiter; + state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; + } + }, + }; + + while (true) { + // Wait thread, decrement waiter and consume signal if exists. + var new_state = state - one_waiter; + if (state & signal_mask != 0) { + new_state = state - one_signal; + } + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } } @@ -536,3 +527,74 @@ test "Condition - broadcasting" { t.join(); } } + +test "Condition - broadcasting - wake all threads" { + // Tests issue #12877 + // This test requires spawning threads + if (builtin.single_threaded) { + return error.SkipZigTest; + } + + const num_threads = 10; + + const BroadcastTest = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + thread_id_to_wake: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, + + fn run(self: *@This(), thread_id: usize) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + // The last broadcast thread to start tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } + + while (self.thread_id_to_wake != thread_id) { + self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); + self.wakeups += 1; + } + if (self.thread_id_to_wake <= num_threads) { + // Signal next thread to wake up. + self.thread_id_to_wake += 1; + self.cond.broadcast(); + } + } + }; + + var broadcast_test = BroadcastTest{}; + var thread_id: usize = 1; + for (broadcast_test.threads) |*t| { + t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); + thread_id += 1; + } + + { + broadcast_test.mutex.lock(); + defer broadcast_test.mutex.unlock(); + + // Wait for all the broadcast threads to spawn. + // timedWait() to detect any potential deadlocks. + while (broadcast_test.count != num_threads) { + try broadcast_test.completed.timedWait( + &broadcast_test.mutex, + 1 * std.time.ns_per_s, + ); + } + + // Signal thread 1 to wake up + broadcast_test.thread_id_to_wake = 1; + broadcast_test.cond.broadcast(); + } + + for (broadcast_test.threads) |t| { + t.join(); + } + //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups}); +}