diff --git a/lib/std/Io/Threaded.zig b/lib/std/Io/Threaded.zig index 6ceaee4505..58e91e25a5 100644 --- a/lib/std/Io/Threaded.zig +++ b/lib/std/Io/Threaded.zig @@ -22,12 +22,30 @@ mutex: std.Thread.Mutex = .{}, cond: std.Thread.Condition = .{}, run_queue: std.SinglyLinkedList = .{}, join_requested: bool = false, -threads: std.ArrayList(std.Thread), stack_size: usize, -cpu_count: usize, // 0 means no limit -concurrency_limit: usize, // 0 means no limit -available_thread_count: usize = 0, -one_shot_thread_count: usize = 0, +/// All threads are spawned detached; this is how we wait until they all exit. +wait_group: std.Thread.WaitGroup = .{}, +/// Maximum thread pool size (excluding main thread) when dispatching async +/// tasks. Until this limit, calls to `Io.async` when all threads are busy will +/// cause a new thread to be spawned and permanently added to the pool. After +/// this limit, calls to `Io.async` when all threads are busy run the task +/// immediately. +/// +/// Defaults to one less than the number of logical CPU cores. +async_limit: Io.Limit, +/// Maximum thread pool size (excluding main thread) for dispatching concurrent +/// tasks. Until this limit, calls to `Io.concurrent` will increase the thread +/// pool size. +/// +/// After this limit, calls to `Io.concurrent` return +/// `error.ConcurrencyUnavailable`. +concurrent_limit: Io.Limit = .unlimited, +/// Error from calling `std.Thread.getCpuCount` in `init`. +cpu_count_error: ?std.Thread.CpuCountError, +/// Number of threads that are unavailable to take tasks. To calculate +/// available count, subtract this from either `async_limit` or +/// `concurrent_limit`. 
+busy_count: usize = 0, wsa: if (is_windows) Wsa else struct {} = .{}, @@ -103,19 +121,18 @@ pub fn init( ) Threaded { if (builtin.single_threaded) return .init_single_threaded; + const cpu_count = std.Thread.getCpuCount(); + var t: Threaded = .{ .allocator = gpa, - .threads = .empty, .stack_size = std.Thread.SpawnConfig.default_stack_size, - .cpu_count = std.Thread.getCpuCount() catch 0, - .concurrency_limit = 0, + .async_limit = if (cpu_count) |n| .limited(n - 1) else |_| .nothing, + .cpu_count_error = if (cpu_count) |_| null else |e| e, .old_sig_io = undefined, .old_sig_pipe = undefined, .have_signal_handler = false, }; - t.threads.ensureTotalCapacity(gpa, t.cpu_count) catch {}; - if (posix.Sigaction != void) { // This causes sending `posix.SIG.IO` to thread to interrupt blocking // syscalls, returning `posix.E.INTR`. @@ -140,19 +157,17 @@ pub fn init( /// * `deinit` is safe, but unnecessary to call. pub const init_single_threaded: Threaded = .{ .allocator = .failing, - .threads = .empty, .stack_size = std.Thread.SpawnConfig.default_stack_size, - .cpu_count = 1, - .concurrency_limit = 0, + .async_limit = .nothing, + .cpu_count_error = null, + .concurrent_limit = .nothing, .old_sig_io = undefined, .old_sig_pipe = undefined, .have_signal_handler = false, }; pub fn deinit(t: *Threaded) void { - const gpa = t.allocator; t.join(); - t.threads.deinit(gpa); if (is_windows and t.wsa.status == .initialized) { if (ws2_32.WSACleanup() != 0) recoverableOsBugDetected(); } @@ -171,10 +186,12 @@ fn join(t: *Threaded) void { t.join_requested = true; } t.cond.broadcast(); - for (t.threads.items) |thread| thread.join(); + t.wait_group.wait(); } fn worker(t: *Threaded) void { + defer t.wait_group.finish(); + t.mutex.lock(); defer t.mutex.unlock(); @@ -184,20 +201,13 @@ fn worker(t: *Threaded) void { const closure: *Closure = @fieldParentPtr("node", closure_node); closure.start(closure); t.mutex.lock(); - t.available_thread_count += 1; + t.busy_count -= 1; } if (t.join_requested) 
break; t.cond.wait(&t.mutex); } } -fn oneShotWorker(t: *Threaded, closure: *Closure) void { - closure.start(closure); - t.mutex.lock(); - defer t.mutex.unlock(); - t.one_shot_thread_count -= 1; -} - pub fn io(t: *Threaded) Io { return .{ .userdata = t, @@ -488,7 +498,7 @@ fn async( start: *const fn (context: *const anyopaque, result: *anyopaque) void, ) ?*Io.AnyFuture { const t: *Threaded = @ptrCast(@alignCast(userdata)); - if (t.cpu_count == 1 or builtin.single_threaded) { + if (builtin.single_threaded or t.async_limit == .nothing) { start(context.ptr, result.ptr); return null; } @@ -500,35 +510,29 @@ fn async( t.mutex.lock(); - if (t.available_thread_count == 0) { - if (t.cpu_count != 0 and t.threads.items.len >= t.cpu_count) { - t.mutex.unlock(); - ac.deinit(gpa); - start(context.ptr, result.ptr); - return null; - } + const busy_count = t.busy_count; - t.threads.ensureUnusedCapacity(gpa, 1) catch { + if (busy_count >= @intFromEnum(t.async_limit)) { + t.mutex.unlock(); + ac.deinit(gpa); + start(context.ptr, result.ptr); + return null; + } + + t.busy_count = busy_count + 1; + + const pool_size = t.wait_group.value(); + if (pool_size - busy_count == 0) { + t.wait_group.start(); + const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch { + t.wait_group.finish(); + t.busy_count = busy_count; t.mutex.unlock(); ac.deinit(gpa); start(context.ptr, result.ptr); return null; }; - - const thread = std.Thread.spawn( - .{ .stack_size = t.stack_size }, - worker, - .{t}, - ) catch { - t.mutex.unlock(); - ac.deinit(gpa); - start(context.ptr, result.ptr); - return null; - }; - - t.threads.appendAssumeCapacity(thread); - } else { - t.available_thread_count -= 1; + thread.detach(); } t.run_queue.prepend(&ac.closure.node); @@ -550,47 +554,33 @@ fn concurrent( const t: *Threaded = @ptrCast(@alignCast(userdata)); const gpa = t.allocator; - const ac = AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch { + const ac 
= AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch return error.ConcurrencyUnavailable; - }; errdefer ac.deinit(gpa); t.mutex.lock(); defer t.mutex.unlock(); - // If there's an avilable thread, use it. - if (t.available_thread_count > 0) { - t.available_thread_count -= 1; - t.run_queue.prepend(&ac.closure.node); - t.cond.signal(); - return @ptrCast(ac); - } + const busy_count = t.busy_count; - // If we can spawn a normal worker, spawn it and use it. - if (t.cpu_count == 0 or t.threads.items.len < t.cpu_count) { - t.threads.ensureUnusedCapacity(gpa, 1) catch return error.ConcurrencyUnavailable; + if (busy_count >= @intFromEnum(t.concurrent_limit)) + return error.ConcurrencyUnavailable; + + t.busy_count = busy_count + 1; + errdefer t.busy_count = busy_count; + + const pool_size = t.wait_group.value(); + if (pool_size - busy_count == 0) { + t.wait_group.start(); + errdefer t.wait_group.finish(); const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch return error.ConcurrencyUnavailable; - - t.threads.appendAssumeCapacity(thread); - t.run_queue.prepend(&ac.closure.node); - t.cond.signal(); - return @ptrCast(ac); + thread.detach(); } - // If we have a concurrencty limit and we havent' hit it yet, - // spawn a new one-shot thread. 
- if (t.concurrency_limit != 0 and t.one_shot_thread_count >= t.concurrency_limit) - return error.ConcurrencyUnavailable; - - t.one_shot_thread_count += 1; - errdefer t.one_shot_thread_count -= 1; - - const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, oneShotWorker, .{ t, &ac.closure }) catch - return error.ConcurrencyUnavailable; - thread.detach(); - + t.run_queue.prepend(&ac.closure.node); + t.cond.signal(); return @ptrCast(ac); } @@ -684,41 +674,37 @@ fn groupAsync( context_alignment: std.mem.Alignment, start: *const fn (*Io.Group, context: *const anyopaque) void, ) void { - if (builtin.single_threaded) return start(group, context.ptr); - const t: *Threaded = @ptrCast(@alignCast(userdata)); + if (builtin.single_threaded or t.async_limit == .nothing) + return start(group, context.ptr); + const gpa = t.allocator; const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch return start(group, context.ptr); t.mutex.lock(); - if (t.available_thread_count == 0) { - if (t.cpu_count != 0 and t.threads.items.len >= t.cpu_count) { - t.mutex.unlock(); - gc.deinit(gpa); - return start(group, context.ptr); - } + const busy_count = t.busy_count; - t.threads.ensureUnusedCapacity(gpa, 1) catch { + if (busy_count >= @intFromEnum(t.async_limit)) { + t.mutex.unlock(); + gc.deinit(gpa); + return start(group, context.ptr); + } + + t.busy_count = busy_count + 1; + + const pool_size = t.wait_group.value(); + if (pool_size - busy_count == 0) { + t.wait_group.start(); + const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch { + t.wait_group.finish(); + t.busy_count = busy_count; t.mutex.unlock(); gc.deinit(gpa); return start(group, context.ptr); }; - - const thread = std.Thread.spawn( - .{ .stack_size = t.stack_size }, - worker, - .{t}, - ) catch { - t.mutex.unlock(); - gc.deinit(gpa); - return start(group, context.ptr); - }; - - t.threads.appendAssumeCapacity(thread); - } else { - t.available_thread_count -= 1; + 
thread.detach(); } // Append to the group linked list inside the mutex to make `Io.Group.async` thread-safe. diff --git a/lib/std/Io/Threaded/test.zig b/lib/std/Io/Threaded/test.zig index 7e6e687cf2..16afae7b63 100644 --- a/lib/std/Io/Threaded/test.zig +++ b/lib/std/Io/Threaded/test.zig @@ -10,7 +10,7 @@ test "concurrent vs main prevents deadlock via oversubscription" { defer threaded.deinit(); const io = threaded.io(); - threaded.cpu_count = 1; + threaded.async_limit = .nothing; var queue: Io.Queue(u8) = .init(&.{}); @@ -38,7 +38,7 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" { defer threaded.deinit(); const io = threaded.io(); - threaded.cpu_count = 1; + threaded.async_limit = .nothing; var queue: Io.Queue(u8) = .init(&.{}); diff --git a/lib/std/Thread/WaitGroup.zig b/lib/std/Thread/WaitGroup.zig index a5970b7d69..8a9107192d 100644 --- a/lib/std/Thread/WaitGroup.zig +++ b/lib/std/Thread/WaitGroup.zig @@ -60,6 +60,10 @@ pub fn isDone(wg: *WaitGroup) bool { return (state / one_pending) == 0; } +pub fn value(wg: *WaitGroup) usize { + return wg.state.load(.monotonic) / one_pending; +} + // Spawns a new thread for the task. This is appropriate when the callee // delegates all work. pub fn spawnManager(