Merge pull request #25998 from ziglang/std.Io.Threaded-async-guarantee

std.Io: guarantee when async() returns, task is already completed or has been successfully assigned a unit of concurrency
Andrew Kelley 2025-11-21 20:56:29 -08:00 committed by GitHub
commit 2ea55d7153
5 changed files with 151 additions and 130 deletions

View File

@@ -580,6 +580,9 @@ pub const VTable = struct {
/// If it returns `null` it means `result` has already been populated and
/// `await` will be a no-op.
///
/// When this function returns non-null, the implementation guarantees that
/// a unit of concurrency has been assigned to the returned task.
///
/// Thread-safe.
async: *const fn (
/// Corresponds to `Io.userdata`.
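In caller terms, a returned future is always making progress. A minimal sketch of what this means through the public wrappers (`compute` and `example` are made-up names; the `io.async`/`future.await` call shape is assumed from the public `std.Io` API rather than quoted from this diff):

const std = @import("std");

fn compute(x: u32) u32 {
    return x + 1;
}

fn example(io: std.Io) void {
    // Once `io.async` returns, `compute` has either already run to
    // completion, or it holds a unit of concurrency (e.g. a worker
    // thread) and is guaranteed to make progress.
    var future = io.async(compute, .{41});
    const result = future.await(io);
    std.debug.assert(result == 42);
}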
@@ -1024,6 +1027,10 @@ pub const Group = struct {
///
/// `function` *may* be called immediately, before `async` returns.
///
/// When this function returns, it is guaranteed that `function` has
/// already been called and completed, or it has successfully been assigned
/// a unit of concurrency.
///
/// After this is called, `wait` or `cancel` must be called before the
/// group is deinitialized.
///
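A hedged sketch of group usage under the new guarantee (the `.init` value and the `group.async`/`group.wait` signatures are assumptions based on this diff, not verified against the full `std.Io` source); the same guarantee is stated below for `Select.async`:

fn example(io: std.Io) void {
    var group: std.Io.Group = .init;
    // By the time this returns, `work` has either finished or owns a
    // unit of concurrency -- it is never parked waiting for a worker
    // that may never come.
    group.async(io, work, .{});
    group.wait(io);
}

fn work() void {}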
@@ -1094,6 +1101,10 @@ pub fn Select(comptime U: type) type {
///
/// `function` *may* be called immediately, before `async` returns.
///
/// When this function returns, it is guaranteed that `function` has
/// already been called and completed, or it has successfully been
/// assigned a unit of concurrency.
///
/// After this is called, `wait` or `cancel` must be called before the
/// select is deinitialized.
///
@@ -1524,8 +1535,11 @@ pub fn Queue(Elem: type) type {
/// not guaranteed to be available until `await` is called.
///
/// `function` *may* be called immediately, before `async` returns. This has
/// weaker guarantees than `concurrent`, making it more portable and reusable.
///
/// When this function returns, it is guaranteed that `function` has already
/// been called and completed, or it has successfully been assigned a unit of
/// concurrency.
///
/// See also:
/// * `Group`
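The practical consequence, sketched below: an `async` task may complete inline, so only `concurrent` can be relied upon for tasks that communicate back to their spawner. (The queue method names `putOne`/`getOne` and the `io.concurrent` call shape are assumptions, not quoted from this commit.)

fn producer(io: std.Io, queue: *std.Io.Queue(u8)) void {
    queue.putOne(io, 42) catch {};
}

fn example(io: std.Io) !void {
    var queue: std.Io.Queue(u8) = .init(&.{});
    // `concurrent` either secures a dedicated unit of concurrency or
    // fails with error.ConcurrencyUnavailable; unlike `async`, it never
    // runs `producer` inline, where blocking on the queue would
    // deadlock against this caller.
    var future = try io.concurrent(producer, .{ io, &queue });
    defer future.await(io);
    _ = try queue.getOne(io);
}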

View File

@@ -13,6 +13,7 @@ const net = std.Io.net;
const HostName = std.Io.net.HostName;
const IpAddress = std.Io.net.IpAddress;
const Allocator = std.mem.Allocator;
const Alignment = std.mem.Alignment;
const assert = std.debug.assert;
const posix = std.posix;
@@ -22,10 +23,30 @@ mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
join_requested: bool = false,
threads: std.ArrayList(std.Thread),
stack_size: usize,
cpu_count: std.Thread.CpuCountError!usize,
concurrent_count: usize,
/// All threads are spawned detached; this is how we wait until they all exit.
wait_group: std.Thread.WaitGroup = .{},
/// Maximum thread pool size (excluding main thread) when dispatching async
/// tasks. Until this limit, calls to `Io.async` when all threads are busy will
/// cause a new thread to be spawned and permanently added to the pool. After
/// this limit, calls to `Io.async` when all threads are busy run the task
/// immediately.
///
/// Defaults to one less than the logical CPU count (the main thread
/// supplies the remaining unit of concurrency).
async_limit: Io.Limit,
/// Maximum thread pool size (excluding main thread) for dispatching
/// concurrent tasks. Until this limit, calls to `Io.concurrent` will
/// increase the thread pool size. After this limit, calls to
/// `Io.concurrent` return `error.ConcurrencyUnavailable`.
concurrent_limit: Io.Limit = .unlimited,
/// Error from calling `std.Thread.getCpuCount` in `init`.
cpu_count_error: ?std.Thread.CpuCountError,
/// Number of threads that are unavailable to take tasks. To calculate
/// available count, subtract this from either `async_limit` or
/// `concurrent_limit`.
busy_count: usize = 0,
wsa: if (is_windows) Wsa else struct {} = .{},
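Taken together, these fields replace the old `threads` list, `cpu_count`, and `concurrent_count`. A minimal configuration sketch (field names are from this diff; the `.limited` constructor appears in `init` below):

// Inside main(), with `gpa` a std.mem.Allocator:
var threaded: std.Io.Threaded = .init(gpa);
defer threaded.deinit();
// Cap the async pool at 4 workers; past that, `Io.async` tasks run
// inline in the caller (still satisfying the completed-or-assigned
// guarantee).
threaded.async_limit = .limited(4);
// Allow up to 64 workers for `Io.concurrent` before it reports
// error.ConcurrencyUnavailable.
threaded.concurrent_limit = .limited(64);
const io = threaded.io();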
@@ -70,8 +91,6 @@ const Closure = struct {
start: Start,
node: std.SinglyLinkedList.Node = .{},
cancel_tid: CancelId,
/// Whether this task bumps minimum number of threads in the pool.
is_concurrent: bool,
const Start = *const fn (*Closure) void;
@@ -90,8 +109,6 @@ const Closure = struct {
}
};
pub const InitError = std.Thread.CpuCountError || Allocator.Error;
/// Related:
/// * `init_single_threaded`
pub fn init(
@@ -103,21 +120,20 @@ pub fn init(
/// here.
gpa: Allocator,
) Threaded {
if (builtin.single_threaded) return .init_single_threaded;
const cpu_count = std.Thread.getCpuCount();
var t: Threaded = .{
.allocator = gpa,
.threads = .empty,
.stack_size = std.Thread.SpawnConfig.default_stack_size,
.cpu_count = std.Thread.getCpuCount(),
.concurrent_count = 0,
.async_limit = if (cpu_count) |n| .limited(n - 1) else |_| .nothing,
.cpu_count_error = if (cpu_count) |_| null else |e| e,
.old_sig_io = undefined,
.old_sig_pipe = undefined,
.have_signal_handler = false,
};
if (t.cpu_count) |n| {
t.threads.ensureTotalCapacityPrecise(gpa, n - 1) catch {};
} else |_| {}
if (posix.Sigaction != void) {
// This causes sending `posix.SIG.IO` to thread to interrupt blocking
// syscalls, returning `posix.E.INTR`.
@@ -142,19 +158,17 @@ pub fn init(
/// * `deinit` is safe, but unnecessary to call.
pub const init_single_threaded: Threaded = .{
.allocator = .failing,
.threads = .empty,
.stack_size = std.Thread.SpawnConfig.default_stack_size,
.cpu_count = 1,
.concurrent_count = 0,
.async_limit = .nothing,
.cpu_count_error = null,
.concurrent_limit = .nothing,
.old_sig_io = undefined,
.old_sig_pipe = undefined,
.have_signal_handler = false,
};
pub fn deinit(t: *Threaded) void {
const gpa = t.allocator;
t.join();
t.threads.deinit(gpa);
if (is_windows and t.wsa.status == .initialized) {
if (ws2_32.WSACleanup() != 0) recoverableOsBugDetected();
}
@@ -173,10 +187,12 @@ fn join(t: *Threaded) void {
t.join_requested = true;
}
t.cond.broadcast();
for (t.threads.items) |thread| thread.join();
t.wait_group.wait();
}
fn worker(t: *Threaded) void {
defer t.wait_group.finish();
t.mutex.lock();
defer t.mutex.unlock();
@@ -184,12 +200,9 @@ fn worker(t: *Threaded) void {
while (t.run_queue.popFirst()) |closure_node| {
t.mutex.unlock();
const closure: *Closure = @fieldParentPtr("node", closure_node);
const is_concurrent = closure.is_concurrent;
closure.start(closure);
t.mutex.lock();
if (is_concurrent) {
t.concurrent_count -= 1;
}
t.busy_count -= 1;
}
if (t.join_requested) break;
t.cond.wait(&t.mutex);
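Since threads are now detached at spawn time, the pool no longer stores `std.Thread` handles; the `WaitGroup` is the single join point. The general pattern, independent of this file (a sketch, not quoted code):

fn spawnDetached(wait_group: *std.Thread.WaitGroup) std.Thread.SpawnError!void {
    // Register before spawning so a concurrent `wait` cannot miss the thread.
    wait_group.start();
    const thread = std.Thread.spawn(.{}, run, .{wait_group}) catch |err| {
        wait_group.finish(); // undo the registration on spawn failure
        return err;
    };
    thread.detach();
}

fn run(wait_group: *std.Thread.WaitGroup) void {
    defer wait_group.finish();
    // ... task body ...
}

// Join point, as in `join` above: wait_group.wait();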
@@ -387,7 +400,7 @@ const AsyncClosure = struct {
func: *const fn (context: *anyopaque, result: *anyopaque) void,
reset_event: ResetEvent,
select_condition: ?*ResetEvent,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
result_offset: usize,
alloc_len: usize,
@@ -432,11 +445,10 @@ const AsyncClosure = struct {
fn init(
gpa: Allocator,
mode: enum { async, concurrent },
result_len: usize,
result_alignment: std.mem.Alignment,
result_alignment: Alignment,
context: []const u8,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
func: *const fn (context: *const anyopaque, result: *anyopaque) void,
) Allocator.Error!*AsyncClosure {
const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(AsyncClosure);
@@ -454,10 +466,6 @@ const AsyncClosure = struct {
.closure = .{
.cancel_tid = .none,
.start = start,
.is_concurrent = switch (mode) {
.async => false,
.concurrent => true,
},
},
.func = func,
.context_alignment = context_alignment,
@@ -470,10 +478,15 @@ const AsyncClosure = struct {
return ac;
}
fn waitAndDeinit(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
ac.reset_event.waitUncancelable();
fn waitAndDeinit(ac: *AsyncClosure, t: *Threaded, result: []u8) void {
ac.reset_event.wait(t) catch |err| switch (err) {
error.Canceled => {
ac.closure.requestCancel();
ac.reset_event.waitUncancelable();
},
};
@memcpy(result, ac.resultPointer()[0..result.len]);
ac.deinit(gpa);
ac.deinit(t.allocator);
}
fn deinit(ac: *AsyncClosure, gpa: Allocator) void {
@@ -485,60 +498,50 @@ const AsyncClosure = struct {
fn async(
userdata: ?*anyopaque,
result: []u8,
result_alignment: std.mem.Alignment,
result_alignment: Alignment,
context: []const u8,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
start: *const fn (context: *const anyopaque, result: *anyopaque) void,
) ?*Io.AnyFuture {
if (builtin.single_threaded) {
const t: *Threaded = @ptrCast(@alignCast(userdata));
if (builtin.single_threaded or t.async_limit == .nothing) {
start(context.ptr, result.ptr);
return null;
}
const t: *Threaded = @ptrCast(@alignCast(userdata));
const cpu_count = t.cpu_count catch {
return concurrent(userdata, result.len, result_alignment, context, context_alignment, start) catch {
start(context.ptr, result.ptr);
return null;
};
};
const gpa = t.allocator;
const ac = AsyncClosure.init(gpa, .async, result.len, result_alignment, context, context_alignment, start) catch {
const ac = AsyncClosure.init(gpa, result.len, result_alignment, context, context_alignment, start) catch {
start(context.ptr, result.ptr);
return null;
};
t.mutex.lock();
const thread_capacity = cpu_count - 1 + t.concurrent_count;
const busy_count = t.busy_count;
t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
if (busy_count >= @intFromEnum(t.async_limit)) {
t.mutex.unlock();
ac.deinit(gpa);
start(context.ptr, result.ptr);
return null;
};
t.run_queue.prepend(&ac.closure.node);
if (t.threads.items.len < thread_capacity) {
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
if (t.threads.items.len == 0) {
assert(t.run_queue.popFirst() == &ac.closure.node);
t.mutex.unlock();
ac.deinit(gpa);
start(context.ptr, result.ptr);
return null;
}
// Rely on other workers to do it.
t.mutex.unlock();
t.cond.signal();
return @ptrCast(ac);
};
t.threads.appendAssumeCapacity(thread);
}
t.busy_count = busy_count + 1;
const pool_size = t.wait_group.value();
if (pool_size - busy_count == 0) {
t.wait_group.start();
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
t.wait_group.finish();
t.busy_count = busy_count;
t.mutex.unlock();
ac.deinit(gpa);
start(context.ptr, result.ptr);
return null;
};
thread.detach();
}
t.run_queue.prepend(&ac.closure.node);
t.mutex.unlock();
t.cond.signal();
return @ptrCast(ac);
@@ -547,45 +550,42 @@ fn async(
fn concurrent(
userdata: ?*anyopaque,
result_len: usize,
result_alignment: std.mem.Alignment,
result_alignment: Alignment,
context: []const u8,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
start: *const fn (context: *const anyopaque, result: *anyopaque) void,
) Io.ConcurrentError!*Io.AnyFuture {
if (builtin.single_threaded) return error.ConcurrencyUnavailable;
const t: *Threaded = @ptrCast(@alignCast(userdata));
const cpu_count = t.cpu_count catch 1;
const gpa = t.allocator;
const ac = AsyncClosure.init(gpa, .concurrent, result_len, result_alignment, context, context_alignment, start) catch {
const ac = AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch
return error.ConcurrencyUnavailable;
};
errdefer ac.deinit(gpa);
t.mutex.lock();
defer t.mutex.unlock();
t.concurrent_count += 1;
const thread_capacity = cpu_count - 1 + t.concurrent_count;
const busy_count = t.busy_count;
t.threads.ensureTotalCapacity(gpa, thread_capacity) catch {
t.mutex.unlock();
ac.deinit(gpa);
if (busy_count >= @intFromEnum(t.concurrent_limit))
return error.ConcurrencyUnavailable;
};
t.run_queue.prepend(&ac.closure.node);
t.busy_count = busy_count + 1;
errdefer t.busy_count = busy_count;
if (t.threads.items.len < thread_capacity) {
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
assert(t.run_queue.popFirst() == &ac.closure.node);
t.mutex.unlock();
ac.deinit(gpa);
const pool_size = t.wait_group.value();
if (pool_size - busy_count == 0) {
t.wait_group.start();
errdefer t.wait_group.finish();
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch
return error.ConcurrencyUnavailable;
};
t.threads.appendAssumeCapacity(thread);
thread.detach();
}
t.mutex.unlock();
t.run_queue.prepend(&ac.closure.node);
t.cond.signal();
return @ptrCast(ac);
}
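The admission check is now identical in shape for `async`, `concurrent`, and `groupAsync`; only the limit and the rejection path differ. A simplified restatement (not the actual code; the mutex and run-queue handling are elided):

fn tryAdmit(busy_count: usize, pool_size: usize, limit: std.Io.Limit) ?bool {
    // Over the limit: reject (run inline for `async`, error for `concurrent`).
    if (busy_count >= @intFromEnum(limit)) return null;
    // `pool_size - busy_count == 0` means every worker already has a
    // task, so accepting this one requires spawning a new thread.
    return pool_size - busy_count == 0;
}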
@@ -597,7 +597,7 @@ const GroupClosure = struct {
/// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
node: std.SinglyLinkedList.Node,
func: *const fn (*Io.Group, context: *anyopaque) void,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
alloc_len: usize,
fn start(closure: *Closure) void {
@@ -638,7 +638,7 @@ const GroupClosure = struct {
t: *Threaded,
group: *Io.Group,
context: []const u8,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
func: *const fn (*Io.Group, context: *const anyopaque) void,
) Allocator.Error!*GroupClosure {
const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(GroupClosure);
@@ -652,7 +652,6 @@ const GroupClosure = struct {
.closure = .{
.cancel_tid = .none,
.start = start,
.is_concurrent = false,
},
.t = t,
.group = group,
@@ -678,45 +677,48 @@ fn groupAsync(
userdata: ?*anyopaque,
group: *Io.Group,
context: []const u8,
context_alignment: std.mem.Alignment,
context_alignment: Alignment,
start: *const fn (*Io.Group, context: *const anyopaque) void,
) void {
if (builtin.single_threaded) return start(group, context.ptr);
const t: *Threaded = @ptrCast(@alignCast(userdata));
const cpu_count = t.cpu_count catch 1;
if (builtin.single_threaded or t.async_limit == .nothing)
return start(group, context.ptr);
const gpa = t.allocator;
const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch {
const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch
return start(group, context.ptr);
};
t.mutex.lock();
const busy_count = t.busy_count;
if (busy_count >= @intFromEnum(t.async_limit)) {
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
}
t.busy_count = busy_count + 1;
const pool_size = t.wait_group.value();
if (pool_size - busy_count == 0) {
t.wait_group.start();
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
t.wait_group.finish();
t.busy_count = busy_count;
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
};
thread.detach();
}
// Append to the group linked list inside the mutex to make `Io.Group.async` thread-safe.
gc.node = .{ .next = @ptrCast(@alignCast(group.token)) };
group.token = &gc.node;
const thread_capacity = cpu_count - 1 + t.concurrent_count;
t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
};
t.run_queue.prepend(&gc.closure.node);
if (t.threads.items.len < thread_capacity) {
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
assert(t.run_queue.popFirst() == &gc.closure.node);
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
};
t.threads.appendAssumeCapacity(thread);
}
// This needs to be done before unlocking the mutex to avoid a race with
// the associated task finishing.
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
@@ -794,25 +796,25 @@ fn await(
userdata: ?*anyopaque,
any_future: *Io.AnyFuture,
result: []u8,
result_alignment: std.mem.Alignment,
result_alignment: Alignment,
) void {
_ = result_alignment;
const t: *Threaded = @ptrCast(@alignCast(userdata));
const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
closure.waitAndDeinit(t.allocator, result);
closure.waitAndDeinit(t, result);
}
fn cancel(
userdata: ?*anyopaque,
any_future: *Io.AnyFuture,
result: []u8,
result_alignment: std.mem.Alignment,
result_alignment: Alignment,
) void {
_ = result_alignment;
const t: *Threaded = @ptrCast(@alignCast(userdata));
const ac: *AsyncClosure = @ptrCast(@alignCast(any_future));
ac.closure.requestCancel();
ac.waitAndDeinit(t.allocator, result);
ac.waitAndDeinit(t, result);
}
fn cancelRequested(userdata: ?*anyopaque) bool {
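With `waitAndDeinit` taking the `Threaded` pointer, both `await` and `cancel` now wait cancelably: if the waiting task is itself canceled, the request is forwarded to the child closure and the wait is retried uncancelably. From the caller's side (hypothetical task names):

fn parent(io: std.Io) void {
    var future = io.async(longRunningTask, .{io});
    // If `parent` is canceled while blocked here, Threaded forwards the
    // cancellation to `longRunningTask` instead of waiting for it to
    // finish on its own schedule.
    future.await(io);
}

fn longRunningTask(io: std.Io) void {
    _ = io; // e.g. a loop that periodically observes cancellation
}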

View File

@@ -10,7 +10,7 @@ test "concurrent vs main prevents deadlock via oversubscription" {
defer threaded.deinit();
const io = threaded.io();
threaded.cpu_count = 1;
threaded.async_limit = .nothing;
var queue: Io.Queue(u8) = .init(&.{});
@@ -38,7 +38,7 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" {
defer threaded.deinit();
const io = threaded.io();
threaded.cpu_count = 1;
threaded.async_limit = .nothing;
var queue: Io.Queue(u8) = .init(&.{});

View File

@@ -1,13 +1,14 @@
//! This struct represents a kernel thread, and acts as a namespace for
//! concurrency primitives that operate on kernel threads. For concurrency
//! primitives that interact with the I/O interface, see `std.Io`.
const builtin = @import("builtin");
const target = builtin.target;
const native_os = builtin.os.tag;
const std = @import("std.zig");
const math = std.math;
const assert = std.debug.assert;
const posix = std.posix;
const windows = std.os.windows;
const testing = std.testing;

View File

@@ -60,6 +60,10 @@ pub fn isDone(wg: *WaitGroup) bool {
return (state / one_pending) == 0;
}
pub fn value(wg: *WaitGroup) usize {
return wg.state.load(.monotonic) / one_pending;
}
// Spawns a new thread for the task. This is appropriate when the callee
// delegates all work.
pub fn spawnManager(
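The new `value` accessor exposes the number of registered, not-yet-finished tasks; `Threaded` above uses it as the current pool size. A small self-contained sketch:

const std = @import("std");

test "WaitGroup.value counts pending tasks" {
    var wg: std.Thread.WaitGroup = .{};
    try std.testing.expect(wg.value() == 0);
    wg.start();
    wg.start();
    try std.testing.expect(wg.value() == 2);
    wg.finish();
    wg.finish();
    wg.wait(); // nothing pending; returns immediately
}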