From 988f58341b60d61ec5be86f3d7dec5738c68d16a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 29 Mar 2025 20:58:07 -0700 Subject: [PATCH] std.Io: introduce cancellation --- lib/std/Io.zig | 63 +++++++-- lib/std/Io/EventLoop.zig | 16 ++- lib/std/Thread/Pool.zig | 277 ++++++++++++++++++++++++--------------- 3 files changed, 239 insertions(+), 117 deletions(-) diff --git a/lib/std/Io.zig b/lib/std/Io.zig index c995a2a76a..5497b73b68 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -922,6 +922,8 @@ vtable: *const VTable, pub const VTable = struct { /// If it returns `null` it means `result` has been already populated and /// `await` will be a no-op. + /// + /// Thread-safe. async: *const fn ( /// Corresponds to `Io.userdata`. userdata: ?*anyopaque, @@ -937,6 +939,8 @@ pub const VTable = struct { ) ?*AnyFuture, /// This function is only called when `async` returns a non-null value. + /// + /// Thread-safe. await: *const fn ( /// Corresponds to `Io.userdata`. userdata: ?*anyopaque, @@ -947,13 +951,41 @@ pub const VTable = struct { result: []u8, ) void, - createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) fs.File.OpenError!fs.File, - openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) fs.File.OpenError!fs.File, + /// Equivalent to `await` but initiates cancel request. + /// + /// This function is only called when `async` returns a non-null value. + /// + /// Thread-safe. + cancel: *const fn ( + /// Corresponds to `Io.userdata`. + userdata: ?*anyopaque, + /// The same value that was returned from `async`. + any_future: *AnyFuture, + /// Points to a buffer where the result is written. + /// The length is equal to size in bytes of result type. + result: []u8, + ) void, + + /// Returns whether the current thread of execution is known to have + /// been requested to cancel. + /// + /// Thread-safe. + cancelRequested: *const fn (?*anyopaque) bool, + + createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File, + openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File, closeFile: *const fn (?*anyopaque, fs.File) void, - read: *const fn (?*anyopaque, file: fs.File, buffer: []u8) fs.File.ReadError!usize, - write: *const fn (?*anyopaque, file: fs.File, buffer: []const u8) fs.File.WriteError!usize, + read: *const fn (?*anyopaque, file: fs.File, buffer: []u8) FileReadError!usize, + write: *const fn (?*anyopaque, file: fs.File, buffer: []const u8) FileWriteError!usize, }; +pub const OpenFlags = fs.File.OpenFlags; +pub const CreateFlags = fs.File.CreateFlags; + +pub const FileOpenError = fs.File.OpenError || error{AsyncCancel}; +pub const FileReadError = fs.File.ReadError || error{AsyncCancel}; +pub const FileWriteError = fs.File.WriteError || error{AsyncCancel}; + pub const AnyFuture = opaque {}; pub fn Future(Result: type) type { @@ -961,6 +993,17 @@ pub fn Future(Result: type) type { any_future: ?*AnyFuture, result: Result, + /// Equivalent to `await` but sets a flag observable to application + /// code that cancellation has been requested. + /// + /// Idempotent. + pub fn cancel(f: *@This(), io: Io) Result { + const any_future = f.any_future orelse return f.result; + io.vtable.cancel(io.userdata, any_future, @ptrCast((&f.result)[0..1])); + f.any_future = null; + return f.result; + } + pub fn await(f: *@This(), io: Io) Result { const any_future = f.any_future orelse return f.result; io.vtable.await(io.userdata, any_future, @ptrCast((&f.result)[0..1])); @@ -994,11 +1037,11 @@ pub fn async(io: Io, function: anytype, args: anytype) Future(@typeInfo(@TypeOf( return future; } -pub fn openFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) fs.File.OpenError!fs.File { +pub fn openFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File { return io.vtable.openFile(io.userdata, dir, sub_path, flags); } -pub fn createFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) fs.File.OpenError!fs.File { +pub fn createFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File { return io.vtable.createFile(io.userdata, dir, sub_path, flags); } @@ -1006,22 +1049,22 @@ pub fn closeFile(io: Io, file: fs.File) void { return io.vtable.closeFile(io.userdata, file); } -pub fn read(io: Io, file: fs.File, buffer: []u8) fs.File.ReadError!usize { +pub fn read(io: Io, file: fs.File, buffer: []u8) FileReadError!usize { return io.vtable.read(io.userdata, file, buffer); } -pub fn write(io: Io, file: fs.File, buffer: []const u8) fs.File.WriteError!usize { +pub fn write(io: Io, file: fs.File, buffer: []const u8) FileWriteError!usize { return io.vtable.write(io.userdata, file, buffer); } -pub fn writeAll(io: Io, file: fs.File, bytes: []const u8) fs.File.WriteError!void { +pub fn writeAll(io: Io, file: fs.File, bytes: []const u8) FileWriteError!void { var index: usize = 0; while (index < bytes.len) { index += try io.write(file, bytes[index..]); } } -pub fn readAll(io: Io, file: fs.File, buffer: []u8) fs.File.ReadError!usize { +pub fn readAll(io: Io, file: fs.File, buffer: []u8) FileReadError!usize { var index: usize = 0; while (index != buffer.len) { const amt = try io.read(file, buffer[index..]); diff --git a/lib/std/Io/EventLoop.zig b/lib/std/Io/EventLoop.zig index 7baca5c853..f4147b16ee 100644 --- a/lib/std/Io/EventLoop.zig +++ b/lib/std/Io/EventLoop.zig @@ -7,12 +7,13 @@ const EventLoop = @This(); const Alignment = std.mem.Alignment; const IoUring = std.os.linux.IoUring; +/// Must be a thread-safe allocator. gpa: Allocator, mutex: std.Thread.Mutex, -queue: std.DoublyLinkedList(void), +queue: std.DoublyLinkedList, /// Atomic copy of queue.len queue_len: u32, -free: std.DoublyLinkedList(void), +free: std.DoublyLinkedList, main_fiber: Fiber, idle_count: usize, threads: std.ArrayListUnmanaged(Thread), @@ -39,7 +40,7 @@ const Thread = struct { const Fiber = struct { context: Context, awaiter: ?*Fiber, - queue_node: std.DoublyLinkedList(void).Node, + queue_node: std.DoublyLinkedList.Node, result_align: Alignment, const finished: ?*Fiber = @ptrFromInt(std.mem.alignBackward(usize, std.math.maxInt(usize), @alignOf(Fiber))); @@ -447,6 +448,15 @@ pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: [] event_loop.recycle(future_fiber); } +pub fn cancel(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void { + const event_loop: *EventLoop = @alignCast(@ptrCast(userdata)); + const future_fiber: *Fiber = @alignCast(@ptrCast(any_future)); + // TODO set a flag that makes all IO operations for this fiber return error.Canceled + if (@atomicLoad(?*Fiber, &future_fiber.awaiter, .acquire) != Fiber.finished) event_loop.yield(null, .{ .register_awaiter = &future_fiber.awaiter }); + @memcpy(result, future_fiber.resultPointer()); + event_loop.recycle(future_fiber); +} + pub fn createFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.CreateFlags) std.fs.File.OpenError!std.fs.File { const el: *EventLoop = @ptrCast(@alignCast(userdata)); diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig index 486d067330..03fdbe21a1 100644 --- a/lib/std/Thread/Pool.zig +++ b/lib/std/Thread/Pool.zig @@ -1,22 +1,27 @@ const builtin = @import("builtin"); const std = @import("std"); +const Allocator = std.mem.Allocator; const assert = std.debug.assert; const WaitGroup = @import("WaitGroup.zig"); +const Io = std.Io; const Pool = @This(); +/// Must be a thread-safe allocator. +allocator: std.mem.Allocator, mutex: std.Thread.Mutex = .{}, cond: std.Thread.Condition = .{}, run_queue: std.SinglyLinkedList = .{}, is_running: bool = true, -/// Must be a thread-safe allocator. -allocator: std.mem.Allocator, -threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread, +threads: std.ArrayListUnmanaged(std.Thread), ids: if (builtin.single_threaded) struct { inline fn deinit(_: @This(), _: std.mem.Allocator) void {} fn getIndex(_: @This(), _: std.Thread.Id) usize { return 0; } } else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void), +stack_size: usize, + +threadlocal var current_closure: ?*AsyncClosure = null; pub const Runnable = struct { runFn: RunProto, @@ -33,48 +38,36 @@ pub const Options = struct { }; pub fn init(pool: *Pool, options: Options) !void { - const allocator = options.allocator; + const gpa = options.allocator; + const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1); + const threads = try gpa.alloc(std.Thread, thread_count); + errdefer gpa.free(threads); pool.* = .{ - .allocator = allocator, - .threads = if (builtin.single_threaded) .{} else &.{}, + .allocator = gpa, + .threads = .initBuffer(threads), .ids = .{}, + .stack_size = options.stack_size, }; - if (builtin.single_threaded) { - return; - } + if (builtin.single_threaded) return; - const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1); if (options.track_ids) { - try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count); + try pool.ids.ensureTotalCapacity(gpa, 1 + thread_count); pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {}); } - - // kill and join any threads we spawned and free memory on error. - pool.threads = try allocator.alloc(std.Thread, thread_count); - var spawned: usize = 0; - errdefer pool.join(spawned); - - for (pool.threads) |*thread| { - thread.* = try std.Thread.spawn(.{ - .stack_size = options.stack_size, - .allocator = allocator, - }, worker, .{pool}); - spawned += 1; - } } pub fn deinit(pool: *Pool) void { - pool.join(pool.threads.len); // kill and join all threads. - pool.ids.deinit(pool.allocator); + const gpa = pool.allocator; + pool.join(); + pool.threads.deinit(gpa); + pool.ids.deinit(gpa); pool.* = undefined; } -fn join(pool: *Pool, spawned: usize) void { - if (builtin.single_threaded) { - return; - } +fn join(pool: *Pool) void { + if (builtin.single_threaded) return; { pool.mutex.lock(); @@ -87,11 +80,7 @@ fn join(pool: *Pool, spawned: usize) void { // wake up any sleeping threads (this can be done outside the mutex) // then wait for all the threads we know are spawned to complete. pool.cond.broadcast(); - for (pool.threads[0..spawned]) |thread| { - thread.join(); - } - - pool.allocator.free(pool.threads); + for (pool.threads.items) |thread| thread.join(); } /// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and @@ -123,26 +112,34 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args } }; - { - pool.mutex.lock(); + pool.mutex.lock(); - const closure = pool.allocator.create(Closure) catch { - pool.mutex.unlock(); - @call(.auto, func, args); - wait_group.finish(); - return; - }; - closure.* = .{ - .arguments = args, - .pool = pool, - .wait_group = wait_group, - }; - - pool.run_queue.prepend(&closure.runnable.node); + const gpa = pool.allocator; + const closure = gpa.create(Closure) catch { pool.mutex.unlock(); + @call(.auto, func, args); + wait_group.finish(); + return; + }; + closure.* = .{ + .arguments = args, + .pool = pool, + .wait_group = wait_group, + }; + + pool.run_queue.prepend(&closure.runnable.node); + + if (pool.threads.items.len < pool.threads.capacity) { + pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{ + .stack_size = pool.stack_size, + .allocator = gpa, + }, worker, .{pool}) catch t: { + pool.threads.items.len -= 1; + break :t undefined; + }; } - // Notify waiting threads outside the lock to try and keep the critical section small. + pool.mutex.unlock(); pool.cond.signal(); } @@ -179,31 +176,39 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar } }; - { - pool.mutex.lock(); + pool.mutex.lock(); - const closure = pool.allocator.create(Closure) catch { - const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId()); - pool.mutex.unlock(); - @call(.auto, func, .{id.?} ++ args); - wait_group.finish(); - return; - }; - closure.* = .{ - .arguments = args, - .pool = pool, - .wait_group = wait_group, - }; - - pool.run_queue.prepend(&closure.runnable.node); + const gpa = pool.allocator; + const closure = gpa.create(Closure) catch { + const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId()); pool.mutex.unlock(); + @call(.auto, func, .{id.?} ++ args); + wait_group.finish(); + return; + }; + closure.* = .{ + .arguments = args, + .pool = pool, + .wait_group = wait_group, + }; + + pool.run_queue.prepend(&closure.runnable.node); + + if (pool.threads.items.len < pool.threads.capacity) { + pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{ + .stack_size = pool.stack_size, + .allocator = gpa, + }, worker, .{pool}) catch t: { + pool.threads.items.len -= 1; + break :t undefined; + }; } - // Notify waiting threads outside the lock to try and keep the critical section small. + pool.mutex.unlock(); pool.cond.signal(); } -pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { +pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) void { if (builtin.single_threaded) { @call(.auto, func, args); return; @@ -222,20 +227,32 @@ pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { } }; - { - pool.mutex.lock(); - defer pool.mutex.unlock(); + pool.mutex.lock(); - const closure = try pool.allocator.create(Closure); - closure.* = .{ - .arguments = args, - .pool = pool, + const gpa = pool.allocator; + const closure = gpa.create(Closure) catch { + pool.mutex.unlock(); + @call(.auto, func, args); + return; + }; + closure.* = .{ + .arguments = args, + .pool = pool, + }; + + pool.run_queue.prepend(&closure.runnable.node); + + if (pool.threads.items.len < pool.threads.capacity) { + pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{ + .stack_size = pool.stack_size, + .allocator = gpa, + }, worker, .{pool}) catch t: { + pool.threads.items.len -= 1; + break :t undefined; }; - - pool.run_queue.prepend(&closure.runnable.node); } - // Notify waiting threads outside the lock to try and keep the critical section small. + pool.mutex.unlock(); pool.cond.signal(); } @@ -254,7 +271,7 @@ test spawn { .allocator = std.testing.allocator, }); defer pool.deinit(); - try pool.spawn(TestFn.checkRun, .{&completed}); + pool.spawn(TestFn.checkRun, .{&completed}); } try std.testing.expectEqual(true, completed); @@ -306,15 +323,17 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void { } pub fn getIdCount(pool: *Pool) usize { - return @intCast(1 + pool.threads.len); + return @intCast(1 + pool.threads.items.len); } -pub fn io(pool: *Pool) std.Io { +pub fn io(pool: *Pool) Io { return .{ .userdata = pool, .vtable = &.{ .@"async" = @"async", .@"await" = @"await", + .cancel = cancel, + .cancelRequested = cancelRequested, .createFile = createFile, .openFile = openFile, .closeFile = closeFile, @@ -326,15 +345,17 @@ pub fn io(pool: *Pool) std.Io { const AsyncClosure = struct { func: *const fn (context: *anyopaque, result: *anyopaque) void, - run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } }, + runnable: Runnable = .{ .runFn = runFn }, reset_event: std.Thread.ResetEvent, + cancel_flag: bool, context_offset: usize, result_offset: usize, fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void { - const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable); - const closure: *AsyncClosure = @alignCast(@fieldParentPtr("run_node", run_node)); + const closure: *AsyncClosure = @alignCast(@fieldParentPtr("runnable", runnable)); + current_closure = closure; closure.func(closure.contextPointer(), closure.resultPointer()); + current_closure = null; closure.reset_event.set(); } @@ -359,16 +380,23 @@ const AsyncClosure = struct { const base: [*]u8 = @ptrCast(closure); return base + closure.context_offset; } + + fn waitAndFree(closure: *AsyncClosure, gpa: Allocator, result: []u8) void { + closure.reset_event.wait(); + const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure); + @memcpy(result, closure.resultPointer()[0..result.len]); + gpa.free(base[0 .. closure.result_offset + result.len]); + } }; -pub fn @"async"( +fn @"async"( userdata: ?*anyopaque, result: []u8, result_alignment: std.mem.Alignment, context: []const u8, context_alignment: std.mem.Alignment, start: *const fn (context: *const anyopaque, result: *anyopaque) void, -) ?*std.Io.AnyFuture { +) ?*Io.AnyFuture { const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); pool.mutex.lock(); @@ -386,46 +414,87 @@ pub fn @"async"( .context_offset = context_offset, .result_offset = result_offset, .reset_event = .{}, + .cancel_flag = false, }; @memcpy(closure.contextPointer()[0..context.len], context); - pool.run_queue.prepend(&closure.run_node); - pool.mutex.unlock(); + pool.run_queue.prepend(&closure.runnable.node); + if (pool.threads.items.len < pool.threads.capacity) { + pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{ + .stack_size = pool.stack_size, + .allocator = gpa, + }, worker, .{pool}) catch t: { + pool.threads.items.len -= 1; + break :t undefined; + }; + } + + pool.mutex.unlock(); pool.cond.signal(); return @ptrCast(closure); } -pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void { - const thread_pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); +fn @"await"(userdata: ?*anyopaque, any_future: *Io.AnyFuture, result: []u8) void { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); const closure: *AsyncClosure = @ptrCast(@alignCast(any_future)); - closure.reset_event.wait(); - const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure); - @memcpy(result, closure.resultPointer()[0..result.len]); - thread_pool.allocator.free(base[0 .. closure.result_offset + result.len]); + closure.waitAndFree(pool.allocator, result); } -pub fn createFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.CreateFlags) std.fs.File.OpenError!std.fs.File { - _ = userdata; +fn cancel(userdata: ?*anyopaque, any_future: *Io.AnyFuture, result: []u8) void { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + const closure: *AsyncClosure = @ptrCast(@alignCast(any_future)); + @atomicStore(bool, &closure.cancel_flag, true, .seq_cst); + closure.waitAndFree(pool.allocator, result); +} + +fn cancelRequested(userdata: ?*anyopaque) bool { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + _ = pool; + const closure = current_closure orelse return false; + return @atomicLoad(bool, &closure.cancel_flag, .unordered); +} + +fn checkCancel(pool: *Pool) error{AsyncCancel}!void { + if (cancelRequested(pool)) return error.AsyncCancel; +} + +pub fn createFile( + userdata: ?*anyopaque, + dir: std.fs.Dir, + sub_path: []const u8, + flags: std.fs.File.CreateFlags, +) Io.FileOpenError!std.fs.File { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + try pool.checkCancel(); return dir.createFile(sub_path, flags); } -pub fn openFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.OpenFlags) std.fs.File.OpenError!std.fs.File { - _ = userdata; +pub fn openFile( + userdata: ?*anyopaque, + dir: std.fs.Dir, + sub_path: []const u8, + flags: std.fs.File.OpenFlags, +) Io.FileOpenError!std.fs.File { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + try pool.checkCancel(); return dir.openFile(sub_path, flags); } pub fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void { - _ = userdata; + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + _ = pool; return file.close(); } -pub fn read(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8) std.fs.File.ReadError!usize { - _ = userdata; +pub fn read(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8) Io.FileReadError!usize { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + try pool.checkCancel(); return file.read(buffer); } -pub fn write(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8) std.fs.File.WriteError!usize { - _ = userdata; +pub fn write(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8) Io.FileWriteError!usize { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + try pool.checkCancel(); return file.write(buffer); }