std.Io: fix Group.wait unsoundness

Previously if a Group.wait was canceled, then a subsequent call to
wait() or cancel() would trip an assertion in the synchronization code.
This commit is contained in:
Andrew Kelley 2025-10-28 15:06:07 -07:00
parent 6c794ce7bc
commit 03fd132b1c
4 changed files with 33 additions and 82 deletions

View File

@ -653,8 +653,7 @@ pub const VTable = struct {
context_alignment: std.mem.Alignment, context_alignment: std.mem.Alignment,
start: *const fn (*Group, context: *const anyopaque) void, start: *const fn (*Group, context: *const anyopaque) void,
) void, ) void,
groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) Cancelable!void, groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
groupWaitUncancelable: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void, groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
/// Blocks until one of the futures from the list has a result ready, such /// Blocks until one of the futures from the list has a result ready, such
@ -1038,29 +1037,18 @@ pub const Group = struct {
io.vtable.groupAsync(io.userdata, g, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start); io.vtable.groupAsync(io.userdata, g, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
} }
/// Blocks until all tasks of the group finish. /// Blocks until all tasks of the group finish. During this time,
/// /// cancellation requests propagate to all members of the group.
/// On success, further calls to `wait`, `waitUncancelable`, and `cancel`
/// do nothing.
///
/// Not threadsafe.
pub fn wait(g: *Group, io: Io) Cancelable!void {
const token = g.token orelse return;
try io.vtable.groupWait(io.userdata, g, token);
g.token = null;
}
/// Equivalent to `wait` except uninterruptible.
/// ///
/// Idempotent. Not threadsafe. /// Idempotent. Not threadsafe.
pub fn waitUncancelable(g: *Group, io: Io) void { pub fn wait(g: *Group, io: Io) void {
const token = g.token orelse return; const token = g.token orelse return;
g.token = null; g.token = null;
io.vtable.groupWaitUncancelable(io.userdata, g, token); io.vtable.groupWait(io.userdata, g, token);
} }
/// Equivalent to `wait` but requests cancellation on all tasks owned by /// Equivalent to `wait` but immediately requests cancellation on all
/// the group. /// members of the group.
/// ///
/// Idempotent. Not threadsafe. /// Idempotent. Not threadsafe.
pub fn cancel(g: *Group, io: Io) void { pub fn cancel(g: *Group, io: Io) void {

View File

@ -859,7 +859,6 @@ pub fn io(k: *Kqueue) Io {
.groupAsync = groupAsync, .groupAsync = groupAsync,
.groupWait = groupWait, .groupWait = groupWait,
.groupWaitUncancelable = groupWaitUncancelable,
.groupCancel = groupCancel, .groupCancel = groupCancel,
.mutexLock = mutexLock, .mutexLock = mutexLock,
@ -1027,15 +1026,7 @@ fn groupAsync(
@panic("TODO"); @panic("TODO");
} }
fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Cancelable!void { fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
const k: *Kqueue = @ptrCast(@alignCast(userdata));
_ = k;
_ = group;
_ = token;
@panic("TODO");
}
fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
const k: *Kqueue = @ptrCast(@alignCast(userdata)); const k: *Kqueue = @ptrCast(@alignCast(userdata));
_ = k; _ = k;
_ = group; _ = group;

View File

@ -177,7 +177,6 @@ pub fn io(t: *Threaded) Io {
.groupAsync = groupAsync, .groupAsync = groupAsync,
.groupWait = groupWait, .groupWait = groupWait,
.groupWaitUncancelable = groupWaitUncancelable,
.groupCancel = groupCancel, .groupCancel = groupCancel,
.mutexLock = mutexLock, .mutexLock = mutexLock,
@ -274,7 +273,6 @@ pub fn ioBasic(t: *Threaded) Io {
.groupAsync = groupAsync, .groupAsync = groupAsync,
.groupWait = groupWait, .groupWait = groupWait,
.groupWaitUncancelable = groupWaitUncancelable,
.groupCancel = groupCancel, .groupCancel = groupCancel,
.mutexLock = mutexLock, .mutexLock = mutexLock,
@ -579,7 +577,9 @@ const GroupClosure = struct {
assert(cancel_tid == .canceling); assert(cancel_tid == .canceling);
} }
syncFinish(group_state, reset_event); const prev_state = group_state.fetchSub(sync_one_pending, .acq_rel);
assert((prev_state / sync_one_pending) > 0);
if (prev_state == (sync_one_pending | sync_is_waiting)) reset_event.set();
} }
fn free(gc: *GroupClosure, gpa: Allocator) void { fn free(gc: *GroupClosure, gpa: Allocator) void {
@ -602,29 +602,6 @@ const GroupClosure = struct {
const sync_is_waiting: usize = 1 << 0; const sync_is_waiting: usize = 1 << 0;
const sync_one_pending: usize = 1 << 1; const sync_one_pending: usize = 1 << 1;
fn syncStart(state: *std.atomic.Value(usize)) void {
const prev_state = state.fetchAdd(sync_one_pending, .monotonic);
assert((prev_state / sync_one_pending) < (std.math.maxInt(usize) / sync_one_pending));
}
fn syncFinish(state: *std.atomic.Value(usize), event: *ResetEvent) void {
const prev_state = state.fetchSub(sync_one_pending, .acq_rel);
assert((prev_state / sync_one_pending) > 0);
if (prev_state == (sync_one_pending | sync_is_waiting)) event.set();
}
fn syncWait(t: *Threaded, state: *std.atomic.Value(usize), event: *ResetEvent) Io.Cancelable!void {
const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
assert(prev_state & sync_is_waiting == 0);
if ((prev_state / sync_one_pending) > 0) try event.wait(t);
}
fn syncWaitUncancelable(state: *std.atomic.Value(usize), event: *ResetEvent) void {
const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
assert(prev_state & sync_is_waiting == 0);
if ((prev_state / sync_one_pending) > 0) event.waitUncancelable();
}
}; };
fn groupAsync( fn groupAsync(
@ -686,13 +663,14 @@ fn groupAsync(
// This needs to be done before unlocking the mutex to avoid a race with // This needs to be done before unlocking the mutex to avoid a race with
// the associated task finishing. // the associated task finishing.
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state); const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
GroupClosure.syncStart(group_state); const prev_state = group_state.fetchAdd(GroupClosure.sync_one_pending, .monotonic);
assert((prev_state / GroupClosure.sync_one_pending) < (std.math.maxInt(usize) / GroupClosure.sync_one_pending));
t.mutex.unlock(); t.mutex.unlock();
t.cond.signal(); t.cond.signal();
} }
fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Cancelable!void { fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
const t: *Threaded = @ptrCast(@alignCast(userdata)); const t: *Threaded = @ptrCast(@alignCast(userdata));
const gpa = t.allocator; const gpa = t.allocator;
@ -700,26 +678,19 @@ fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Canc
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state); const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
const reset_event: *ResetEvent = @ptrCast(&group.context); const reset_event: *ResetEvent = @ptrCast(&group.context);
try GroupClosure.syncWait(t, group_state, reset_event); const prev_state = group_state.fetchAdd(GroupClosure.sync_is_waiting, .acquire);
assert(prev_state & GroupClosure.sync_is_waiting == 0);
if ((prev_state / GroupClosure.sync_one_pending) > 0) reset_event.wait(t) catch |err| switch (err) {
error.Canceled => {
var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token)); var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
while (true) { while (true) {
const gc: *GroupClosure = @fieldParentPtr("node", node); const gc: *GroupClosure = @fieldParentPtr("node", node);
const node_next = node.next; gc.closure.requestCancel();
gc.free(gpa); node = node.next orelse break;
node = node_next orelse break;
} }
} reset_event.waitUncancelable();
},
fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void { };
const t: *Threaded = @ptrCast(@alignCast(userdata));
const gpa = t.allocator;
if (builtin.single_threaded) return;
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
const reset_event: *ResetEvent = @ptrCast(&group.context);
GroupClosure.syncWaitUncancelable(group_state, reset_event);
var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token)); var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
while (true) { while (true) {
@ -747,7 +718,9 @@ fn groupCancel(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state); const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
const reset_event: *ResetEvent = @ptrCast(&group.context); const reset_event: *ResetEvent = @ptrCast(&group.context);
GroupClosure.syncWaitUncancelable(group_state, reset_event); const prev_state = group_state.fetchAdd(GroupClosure.sync_is_waiting, .acquire);
assert(prev_state & GroupClosure.sync_is_waiting == 0);
if ((prev_state / GroupClosure.sync_one_pending) > 0) reset_event.waitUncancelable();
{ {
var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token)); var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
@ -1549,7 +1522,7 @@ fn dirAccessPosix(
.FAULT => |err| return errnoBug(err), .FAULT => |err| return errnoBug(err),
.IO => return error.InputOutput, .IO => return error.InputOutput,
.NOMEM => return error.SystemResources, .NOMEM => return error.SystemResources,
.ILSEQ => return error.BadPathName, // TODO move to wasi .ILSEQ => return error.BadPathName,
else => |err| return posix.unexpectedErrno(err), else => |err| return posix.unexpectedErrno(err),
} }
} }

View File

@ -280,9 +280,8 @@ pub fn connectMany(
.address => |address| group.async(io, enqueueConnection, .{ address, io, results, options }), .address => |address| group.async(io, enqueueConnection, .{ address, io, results, options }),
.canonical_name => continue, .canonical_name => continue,
.end => |lookup_result| { .end => |lookup_result| {
results.putOneUncancelable(io, .{ group.wait(io);
.end = if (group.wait(io)) lookup_result else |err| err, results.putOneUncancelable(io, .{ .end = lookup_result });
});
return; return;
}, },
} else |err| switch (err) { } else |err| switch (err) {