stdlib: Thread.Condition wake only if signaled

The previous implementation didn't check whether there were pending signals
after returning from futex.wait. While that is OK for the broadcast case, it
can result in multiple wakeups when only one thread is signaled.
This implementation checks that there are pending signals before
returning from wait.
It is similar to the original implementation but without the initial
signal check: here we first go to the futex and then check for a pending
signal.
This commit is contained in:
Igor Anić 2022-11-21 17:26:54 +01:00
parent f229b74099
commit 9947b47d80

View File

@ -204,40 +204,44 @@ const FutexImpl = struct {
// - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
//
// Acquire barrier to ensure the epoch load happens before the state load.
const epoch = self.epoch.load(.Acquire);
var epoch = self.epoch.load(.Acquire);
var state = self.state.fetchAdd(one_waiter, .Monotonic);
assert(state & waiter_mask != waiter_mask);
state += one_waiter;
var futex_deadline = Futex.Deadline.init(timeout);
mutex.unlock();
defer mutex.lock();
futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
// On timeout, we must decrement the waiter we added above.
error.Timeout => {
while (true) {
// If there's a signal when we're timing out, consume it and report being woken up instead.
// Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
while (state & signal_mask != 0) {
const new_state = state - one_waiter - one_signal;
state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
}
// Remove the waiter we added and officially return timed out.
const new_state = state - one_waiter;
state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
}
},
};
var futex_deadline = Futex.Deadline.init(timeout);
while (true) {
// Wait thread, decrement waiter and consume signal if exists.
var new_state = state - one_waiter;
if (state & signal_mask != 0) {
new_state = state - one_signal;
futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
// On timeout, we must decrement the waiter we added above.
error.Timeout => {
while (true) {
// If there's a signal when we're timing out, consume it and report being woken up instead.
// Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
while (state & signal_mask != 0) {
const new_state = state - one_waiter - one_signal;
state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
}
// Remove the waiter we added and officially return timed out.
const new_state = state - one_waiter;
state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
}
},
};
epoch = self.epoch.load(.Acquire);
state = self.state.load(.Monotonic);
// Try to wake up by consuming a signal and decrementing the waiter we added previously.
// Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
while (state & signal_mask != 0) {
const new_state = state - one_waiter - one_signal;
state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
}
state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
}
}
@ -535,66 +539,142 @@ test "Condition - broadcasting - wake all threads" {
return error.SkipZigTest;
}
var num_runs: usize = 1;
const num_threads = 10;
const BroadcastTest = struct {
mutex: Mutex = .{},
cond: Condition = .{},
completed: Condition = .{},
count: usize = 0,
thread_id_to_wake: usize = 0,
threads: [num_threads]std.Thread = undefined,
wakeups: usize = 0,
while (num_runs > 0) : (num_runs -= 1) {
const BroadcastTest = struct {
mutex: Mutex = .{},
cond: Condition = .{},
completed: Condition = .{},
count: usize = 0,
thread_id_to_wake: usize = 0,
threads: [num_threads]std.Thread = undefined,
wakeups: usize = 0,
fn run(self: *@This(), thread_id: usize) void {
self.mutex.lock();
defer self.mutex.unlock();
fn run(self: *@This(), thread_id: usize) void {
self.mutex.lock();
defer self.mutex.unlock();
// The last broadcast thread to start tells the main test thread it's completed.
self.count += 1;
if (self.count == num_threads) {
self.completed.signal();
// The last broadcast thread to start tells the main test thread it's completed.
self.count += 1;
if (self.count == num_threads) {
self.completed.signal();
}
while (self.thread_id_to_wake != thread_id) {
self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
self.wakeups += 1;
}
if (self.thread_id_to_wake <= num_threads) {
// Signal next thread to wake up.
self.thread_id_to_wake += 1;
self.cond.broadcast();
}
}
};
var broadcast_test = BroadcastTest{};
var thread_id: usize = 1;
for (broadcast_test.threads) |*t| {
t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
thread_id += 1;
}
{
broadcast_test.mutex.lock();
defer broadcast_test.mutex.unlock();
// Wait for all the broadcast threads to spawn.
// timedWait() to detect any potential deadlocks.
while (broadcast_test.count != num_threads) {
try broadcast_test.completed.timedWait(
&broadcast_test.mutex,
1 * std.time.ns_per_s,
);
}
while (self.thread_id_to_wake != thread_id) {
self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
// Signal thread 1 to wake up
broadcast_test.thread_id_to_wake = 1;
broadcast_test.cond.broadcast();
}
for (broadcast_test.threads) |t| {
t.join();
}
}
}
test "Condition - signal wakes one" {
// This test requires spawning threads
if (builtin.single_threaded) {
return error.SkipZigTest;
}
var num_runs: usize = 1;
const num_threads = 3;
const timeoutDelay = 10 * std.time.ns_per_ms;
while (num_runs > 0) : (num_runs -= 1) {
// Start multiple runner threads, wait for them to start, and then send the
// signal. Expect that one thread wakes up and all the others time out.
//
// The test depends on the delay in timedWait! If it is too small, all threads
// can time out before any one of them gets woken up.
const Runner = struct {
mutex: Mutex = .{},
cond: Condition = .{},
completed: Condition = .{},
count: usize = 0,
threads: [num_threads]std.Thread = undefined,
wakeups: usize = 0,
timeouts: usize = 0,
fn run(self: *@This()) void {
self.mutex.lock();
defer self.mutex.unlock();
// The last started thread tells the main test thread it's completed.
self.count += 1;
if (self.count == num_threads) {
self.completed.signal();
}
self.cond.timedWait(&self.mutex, timeoutDelay) catch {
self.timeouts += 1;
return;
};
self.wakeups += 1;
}
if (self.thread_id_to_wake <= num_threads) {
// Signal next thread to wake up.
self.thread_id_to_wake += 1;
self.cond.broadcast();
};
// Start threads
var runner = Runner{};
for (runner.threads) |*t| {
t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner});
}
{
runner.mutex.lock();
defer runner.mutex.unlock();
// Wait for all the threads to spawn.
// timedWait() to detect any potential deadlocks.
while (runner.count != num_threads) {
try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s);
}
}
};
var broadcast_test = BroadcastTest{};
var thread_id: usize = 1;
for (broadcast_test.threads) |*t| {
t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
thread_id += 1;
}
{
broadcast_test.mutex.lock();
defer broadcast_test.mutex.unlock();
// Wait for all the broadcast threads to spawn.
// timedWait() to detect any potential deadlocks.
while (broadcast_test.count != num_threads) {
try broadcast_test.completed.timedWait(
&broadcast_test.mutex,
1 * std.time.ns_per_s,
);
// Signal one thread, the others should get timeout.
runner.cond.signal();
}
// Signal thread 1 to wake up
broadcast_test.thread_id_to_wake = 1;
broadcast_test.cond.broadcast();
}
for (runner.threads) |t| {
t.join();
}
for (broadcast_test.threads) |t| {
t.join();
// Expect that only one thread got the signal
try std.testing.expectEqual(runner.wakeups, 1);
try std.testing.expectEqual(runner.timeouts, num_threads - 1);
}
//std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups});
}