std.Thread.Mutex.Recursive: alternate implementation

This version is simpler. Thanks King!
Andrew Kelley 2024-06-12 18:07:39 -07:00
parent fad223d92e
commit 5fc1f8a32b


@@ -30,7 +30,13 @@ pub const init: Recursive = .{
 /// Otherwise, returns `true` and the caller should `unlock()` the Mutex to release it.
 pub fn tryLock(r: *Recursive) bool {
     const current_thread_id = std.Thread.getCurrentId();
-    return tryLockInner(r, current_thread_id);
+    if (@atomicLoad(std.Thread.Id, &r.thread_id, .unordered) != current_thread_id) {
+        if (!r.mutex.tryLock()) return false;
+        assert(r.lock_count == 0);
+        @atomicStore(std.Thread.Id, &r.thread_id, current_thread_id, .unordered);
+    }
+    r.lock_count += 1;
+    return true;
 }
 
 /// Acquires the `Mutex`, blocking the current thread while the mutex is
@@ -42,12 +48,12 @@ pub fn tryLock(r: *Recursive) bool {
 /// of whether the lock was already held by the same thread.
 pub fn lock(r: *Recursive) void {
     const current_thread_id = std.Thread.getCurrentId();
-    if (!tryLockInner(r, current_thread_id)) {
+    if (@atomicLoad(std.Thread.Id, &r.thread_id, .unordered) != current_thread_id) {
         r.mutex.lock();
         assert(r.lock_count == 0);
-        r.lock_count = 1;
-        @atomicStore(std.Thread.Id, &r.thread_id, current_thread_id, .monotonic);
+        @atomicStore(std.Thread.Id, &r.thread_id, current_thread_id, .unordered);
     }
+    r.lock_count += 1;
 }
 
 /// Releases the `Mutex` which was previously acquired with `lock` or `tryLock`.
@@ -57,30 +63,10 @@ pub fn lock(r: *Recursive) void {
 pub fn unlock(r: *Recursive) void {
     r.lock_count -= 1;
     if (r.lock_count == 0) {
-        // Prevent race where:
-        // * Thread A obtains lock and has not yet stored the new thread id.
-        // * Thread B loads the thread id after tryLock() false and observes stale thread id.
-        @atomicStore(std.Thread.Id, &r.thread_id, invalid_thread_id, .seq_cst);
+        @atomicStore(std.Thread.Id, &r.thread_id, invalid_thread_id, .unordered);
         r.mutex.unlock();
     }
 }
 
-fn tryLockInner(r: *Recursive, current_thread_id: std.Thread.Id) bool {
-    if (r.mutex.tryLock()) {
-        assert(r.lock_count == 0);
-        r.lock_count = 1;
-        @atomicStore(std.Thread.Id, &r.thread_id, current_thread_id, .monotonic);
-        return true;
-    }
-    const locked_thread_id = @atomicLoad(std.Thread.Id, &r.thread_id, .monotonic);
-    if (locked_thread_id == current_thread_id) {
-        r.lock_count += 1;
-        return true;
-    }
-    return false;
-}
-
 /// A value that does not alias any other thread id.
-const invalid_thread_id: std.Thread.Id = 0;
+const invalid_thread_id: std.Thread.Id = std.math.maxInt(std.Thread.Id);
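
For readers skimming the diff, here is a minimal usage sketch of the resulting std.Thread.Mutex.Recursive API (not part of the commit). The init declaration and the lock/tryLock/unlock method names come from the diff above; the counter, add, and addTwice names are illustrative only.

const std = @import("std");
const assert = std.debug.assert;

// Shared state guarded by a recursive mutex; `init` is the public
// initializer referenced in the first hunk's context line.
var counter_mutex = std.Thread.Mutex.Recursive.init;
var counter: u64 = 0;

fn add(amount: u64) void {
    counter_mutex.lock();
    defer counter_mutex.unlock();
    counter += amount;
}

fn addTwice(amount: u64) void {
    // The same thread re-acquires the mutex via add() below; with the
    // recursive implementation this nests instead of deadlocking, and
    // each lock() is still paired with an unlock().
    counter_mutex.lock();
    defer counter_mutex.unlock();
    add(amount);
    add(amount);
}

pub fn main() void {
    addTwice(21);
    assert(counter == 42);
}

As for the relaxed orderings: the .unordered loads and stores appear to be sufficient because the thread_id comparison can only succeed against a value the current thread itself stored while holding the underlying mutex; any other thread that observes a stale id simply falls through to a regular mutex acquire.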