diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig index 6829a9f15c..3625aab576 100644 --- a/lib/std/Thread/Condition.zig +++ b/lib/std/Thread/Condition.zig @@ -204,40 +204,44 @@ const FutexImpl = struct { // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) // // Acquire barrier to ensure the epoch load happens before the state load. - const epoch = self.epoch.load(.Acquire); + var epoch = self.epoch.load(.Acquire); var state = self.state.fetchAdd(one_waiter, .Monotonic); assert(state & waiter_mask != waiter_mask); state += one_waiter; - var futex_deadline = Futex.Deadline.init(timeout); mutex.unlock(); defer mutex.lock(); - futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { - // On timeout, we must decrement the waiter we added above. - error.Timeout => { - while (true) { - // If there's a signal when we're timing out, consume it and report being woken up instead. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; - } - - // Remove the waiter we added and officially return timed out. - const new_state = state - one_waiter; - state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; - } - }, - }; + var futex_deadline = Futex.Deadline.init(timeout); while (true) { - // Wait thread, decrement waiter and consume signal if exists. - var new_state = state - one_waiter; - if (state & signal_mask != 0) { - new_state = state - one_signal; + futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { + // On timeout, we must decrement the waiter we added above. + error.Timeout => { + while (true) { + // If there's a signal when we're timing out, consume it and report being woken up instead. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; + } + + // Remove the waiter we added and officially return timed out. + const new_state = state - one_waiter; + state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; + } + }, + }; + + epoch = self.epoch.load(.Acquire); + state = self.state.load(.Monotonic); + + // Try to wake up by consuming a signal and decremented the waiter we added previously. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } } @@ -535,66 +539,142 @@ test "Condition - broadcasting - wake all threads" { return error.SkipZigTest; } + var num_runs: usize = 1; const num_threads = 10; - const BroadcastTest = struct { - mutex: Mutex = .{}, - cond: Condition = .{}, - completed: Condition = .{}, - count: usize = 0, - thread_id_to_wake: usize = 0, - threads: [num_threads]std.Thread = undefined, - wakeups: usize = 0, + while (num_runs > 0) : (num_runs -= 1) { + const BroadcastTest = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + thread_id_to_wake: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, - fn run(self: *@This(), thread_id: usize) void { - self.mutex.lock(); - defer self.mutex.unlock(); + fn run(self: *@This(), thread_id: usize) void { + self.mutex.lock(); + defer self.mutex.unlock(); - // The last broadcast thread to start tells the main test thread it's completed. - self.count += 1; - if (self.count == num_threads) { - self.completed.signal(); + // The last broadcast thread to start tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } + + while (self.thread_id_to_wake != thread_id) { + self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); + self.wakeups += 1; + } + if (self.thread_id_to_wake <= num_threads) { + // Signal next thread to wake up. + self.thread_id_to_wake += 1; + self.cond.broadcast(); + } + } + }; + + var broadcast_test = BroadcastTest{}; + var thread_id: usize = 1; + for (broadcast_test.threads) |*t| { + t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); + thread_id += 1; + } + + { + broadcast_test.mutex.lock(); + defer broadcast_test.mutex.unlock(); + + // Wait for all the broadcast threads to spawn. + // timedWait() to detect any potential deadlocks. + while (broadcast_test.count != num_threads) { + try broadcast_test.completed.timedWait( + &broadcast_test.mutex, + 1 * std.time.ns_per_s, + ); } - while (self.thread_id_to_wake != thread_id) { - self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); + // Signal thread 1 to wake up + broadcast_test.thread_id_to_wake = 1; + broadcast_test.cond.broadcast(); + } + + for (broadcast_test.threads) |t| { + t.join(); + } + } +} + +test "Condition - signal wakes one" { + // This test requires spawning threads + if (builtin.single_threaded) { + return error.SkipZigTest; + } + + var num_runs: usize = 1; + const num_threads = 3; + const timeoutDelay = 10 * std.time.ns_per_ms; + + while (num_runs > 0) : (num_runs -= 1) { + + // Start multiple runner threads, wait for them to start and send the signal + // then. Expect that one thread wake up and all other times out. + // + // Test depends on delay in timedWait! If too small all threads can timeout + // before any one gets wake up. + + const Runner = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, + timeouts: usize = 0, + + fn run(self: *@This()) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + // The last started thread tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } + + self.cond.timedWait(&self.mutex, timeoutDelay) catch { + self.timeouts += 1; + return; + }; self.wakeups += 1; } - if (self.thread_id_to_wake <= num_threads) { - // Signal next thread to wake up. - self.thread_id_to_wake += 1; - self.cond.broadcast(); + }; + + // Start threads + var runner = Runner{}; + for (runner.threads) |*t| { + t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner}); + } + + { + runner.mutex.lock(); + defer runner.mutex.unlock(); + + // Wait for all the threads to spawn. + // timedWait() to detect any potential deadlocks. + while (runner.count != num_threads) { + try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s); } - } - }; - - var broadcast_test = BroadcastTest{}; - var thread_id: usize = 1; - for (broadcast_test.threads) |*t| { - t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); - thread_id += 1; - } - - { - broadcast_test.mutex.lock(); - defer broadcast_test.mutex.unlock(); - - // Wait for all the broadcast threads to spawn. - // timedWait() to detect any potential deadlocks. - while (broadcast_test.count != num_threads) { - try broadcast_test.completed.timedWait( - &broadcast_test.mutex, - 1 * std.time.ns_per_s, - ); + // Signal one thread, the others should get timeout. + runner.cond.signal(); } - // Signal thread 1 to wake up - broadcast_test.thread_id_to_wake = 1; - broadcast_test.cond.broadcast(); - } + for (runner.threads) |t| { + t.join(); + } - for (broadcast_test.threads) |t| { - t.join(); + // Expect that only one got singal + try std.testing.expectEqual(runner.wakeups, 1); + try std.testing.expectEqual(runner.timeouts, num_threads - 1); } - //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups}); }