diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig index 7a479e5540..fb48db8e53 100644 --- a/lib/std/Thread/Condition.zig +++ b/lib/std/Thread/Condition.zig @@ -17,6 +17,10 @@ pub fn wait(cond: *Condition, mutex: *Mutex) void { cond.impl.wait(mutex); } +pub fn timedWait(cond: *Condition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void { + try cond.impl.timedWait(mutex, timeout_ns); +} + pub fn signal(cond: *Condition) void { cond.impl.signal(); } @@ -41,6 +45,14 @@ pub const SingleThreadedCondition = struct { unreachable; // deadlock detected } + pub fn timedWait(cond: *SingleThreadedCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void { + _ = cond; + _ = mutex; + _ = timeout_ns; + std.time.sleep(timeout_ns); + return error.TimedOut; + } + pub fn signal(cond: *SingleThreadedCondition) void { _ = cond; } @@ -63,6 +75,25 @@ pub const WindowsCondition = struct { assert(rc != windows.FALSE); } + pub fn timedWait(cond: *WindowsCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void { + var timeout_checked = std.math.cast(windows.DWORD, timeout_ns / std.time.ns_per_ms) catch overflow: { + break :overflow std.math.maxInt(windows.DWORD); + }; + + // Handle the case where timeout is INFINITE, otherwise SleepConditionVariableSRW's time-out never elapses + const timeout_overflowed = timeout_checked == windows.INFINITE; + timeout_checked -= @boolToInt(timeout_overflowed); + + const rc = windows.kernel32.SleepConditionVariableSRW( + &cond.cond, + &mutex.impl.srwlock, + timeout_checked, + @as(windows.ULONG, 0), + ); + if (rc == windows.FALSE and windows.kernel32.GetLastError() == windows.Win32Error.TIMEOUT) return error.TimedOut; + assert(rc != windows.FALSE); + } + pub fn signal(cond: *WindowsCondition) void { windows.kernel32.WakeConditionVariable(&cond.cond); } @@ -80,6 +111,24 @@ pub const PthreadCondition = struct { assert(rc == .SUCCESS); } + pub fn timedWait(cond: *PthreadCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void { + var ts: std.os.timespec = undefined; + std.os.clock_gettime(std.os.CLOCK.REALTIME, &ts) catch unreachable; + ts.tv_sec += @intCast(@TypeOf(ts.tv_sec), timeout_ns / std.time.ns_per_s); + ts.tv_nsec += @intCast(@TypeOf(ts.tv_nsec), timeout_ns % std.time.ns_per_s); + if (ts.tv_nsec >= std.time.ns_per_s) { + ts.tv_sec += 1; + ts.tv_nsec -= std.time.ns_per_s; + } + + const rc = std.c.pthread_cond_timedwait(&cond.cond, &mutex.impl.pthread_mutex, &ts); + return switch (rc) { + .SUCCESS => {}, + .TIMEDOUT => error.TimedOut, + else => unreachable, + }; + } + pub fn signal(cond: *PthreadCondition) void { const rc = std.c.pthread_cond_signal(&cond.cond); assert(rc == .SUCCESS); @@ -100,6 +149,7 @@ pub const AtomicCondition = struct { pub const QueueItem = struct { futex: i32 = 0, + dequeued: bool = false, fn wait(cond: *@This()) void { while (@atomicLoad(i32, &cond.futex, .Acquire) == 0) { @@ -122,6 +172,39 @@ pub const AtomicCondition = struct { } } + pub fn timedWait(cond: *@This(), timeout_ns: u64) error{TimedOut}!void { + const start_time = std.time.nanoTimestamp(); + while (@atomicLoad(i32, &cond.futex, .Acquire) == 0) { + switch (builtin.os.tag) { + .linux => { + var ts: std.os.timespec = undefined; + ts.tv_sec = @intCast(@TypeOf(ts.tv_sec), timeout_ns / std.time.ns_per_s); + ts.tv_nsec = @intCast(@TypeOf(ts.tv_nsec), timeout_ns % std.time.ns_per_s); + switch (linux.getErrno(linux.futex_wait( + &cond.futex, + linux.FUTEX.PRIVATE_FLAG | linux.FUTEX.WAIT, + 0, + &ts, + ))) { + .SUCCESS => {}, + .INTR => {}, + .AGAIN => {}, + .TIMEDOUT => return error.TimedOut, + .INVAL => {}, // possibly timeout overflow + .FAULT => unreachable, + else => unreachable, + } + }, + else => { + if (std.time.nanoTimestamp() - start_time >= timeout_ns) { + return error.TimedOut; + } + std.atomic.spinLoopHint(); + }, + } + } + } + fn notify(cond: *@This()) void { @atomicStore(i32, &cond.futex, 1, .Release); @@ -158,6 +241,41 @@ pub const AtomicCondition = struct { mutex.lock(); } + pub fn timedWait(cond: *AtomicCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void { + var waiter = QueueList.Node{ .data = .{} }; + + { + cond.queue_mutex.lock(); + defer cond.queue_mutex.unlock(); + + cond.queue_list.prepend(&waiter); + @atomicStore(bool, &cond.pending, true, .SeqCst); + } + + var timed_out = false; + mutex.unlock(); + defer mutex.lock(); + waiter.data.timedWait(timeout_ns) catch |err| switch (err) { + error.TimedOut => { + defer if (!timed_out) { + waiter.data.wait(); + }; + cond.queue_mutex.lock(); + defer cond.queue_mutex.unlock(); + + if (!waiter.data.dequeued) { + timed_out = true; + cond.queue_list.remove(&waiter); + } + }, + else => unreachable, + }; + + if (timed_out) { + return error.TimedOut; + } + } + pub fn signal(cond: *AtomicCondition) void { if (@atomicLoad(bool, &cond.pending, .SeqCst) == false) return; @@ -167,12 +285,16 @@ pub const AtomicCondition = struct { defer cond.queue_mutex.unlock(); const maybe_waiter = cond.queue_list.popFirst(); + if (maybe_waiter) |waiter| { + waiter.data.dequeued = true; + } @atomicStore(bool, &cond.pending, cond.queue_list.first != null, .SeqCst); break :blk maybe_waiter; }; - if (maybe_waiter) |waiter| + if (maybe_waiter) |waiter| { waiter.data.notify(); + } } pub fn broadcast(cond: *AtomicCondition) void { @@ -186,12 +308,19 @@ pub const AtomicCondition = struct { defer cond.queue_mutex.unlock(); const waiters = cond.queue_list; + + var it = waiters.first; + while (it) |node| : (it = node.next) { + node.data.dequeued = true; + } + cond.queue_list = .{}; break :blk waiters; }; - while (waiters.popFirst()) |waiter| + while (waiters.popFirst()) |waiter| { waiter.data.notify(); + } } }; @@ -238,3 +367,45 @@ test "Thread.Condition" { for (threads) |t| t.join(); } + +test "Thread.Condition.timedWait" { + if (builtin.single_threaded) { + return error.SkipZigTest; + } + + var cond = Condition{}; + var mut = Mutex{}; + + // Expect a timeout, as the condition variable is never signaled + { + mut.lock(); + defer mut.unlock(); + try testing.expectError(error.TimedOut, cond.timedWait(&mut, 10 * std.time.ns_per_ms)); + } + + // Expect a signal before timeout + { + const TestContext = struct { + cond: *Condition, + mutex: *Mutex, + n: *u32, + fn worker(ctx: *@This()) void { + ctx.mutex.lock(); + defer ctx.mutex.unlock(); + ctx.n.* = 1; + ctx.cond.signal(); + } + }; + + var n: u32 = 0; + + var ctx = TestContext{ .cond = &cond, .mutex = &mut, .n = &n }; + mut.lock(); + var thread = try std.Thread.spawn(.{}, TestContext.worker, .{&ctx}); + // Looped check to handle spurious wakeups + while (n != 1) try cond.timedWait(&mut, 500 * std.time.ns_per_ms); + mut.unlock(); + try testing.expect(n == 1); + thread.join(); + } +}