From cb9f9bf58d9cf18c8fc70967c43240b1ea0f9ca1 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 24 Mar 2025 18:49:03 -0700 Subject: [PATCH] make thread pool satisfy async/await interface --- lib/std/Thread/Pool.zig | 62 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig index 86e8e87056..f1d2a7338f 100644 --- a/lib/std/Thread/Pool.zig +++ b/lib/std/Thread/Pool.zig @@ -1,7 +1,8 @@ -const std = @import("std"); const builtin = @import("builtin"); -const Pool = @This(); +const std = @import("std"); +const assert = std.debug.assert; const WaitGroup = @import("WaitGroup.zig"); +const Pool = @This(); mutex: std.Thread.Mutex = .{}, cond: std.Thread.Condition = .{}, @@ -307,3 +308,60 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void { pub fn getIdCount(pool: *Pool) usize { return @intCast(1 + pool.threads.len); } + +const AsyncClosure = struct { + func: *const fn (context: ?*anyopaque, result: *anyopaque) void, + context: ?*anyopaque, + run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } }, + reset_event: std.Thread.ResetEvent, + + fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void { + const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable); + const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); + closure.func(closure.context, closure.resultPointer()); + closure.reset_event.set(); + } + + fn resultPointer(closure: *@This()) [*]u8 { + const base: [*]u8 = @ptrCast(closure); + return base + @sizeOf(@This()); + } +}; + +pub fn @"async"( + userdata: ?*anyopaque, + eager_result: []u8, + context: ?*anyopaque, + start: *const fn (context: ?*anyopaque, result: *anyopaque) void, +) ?*std.Io.AnyFuture { + const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + pool.mutex.lock(); + + const gpa = pool.allocator; + const n = @sizeOf(AsyncClosure) + eager_result.len; + const closure: *AsyncClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, @alignOf(AsyncClosure), n) catch { + pool.mutex.unlock(); + start(context, eager_result.ptr); + return null; + })); + closure.* = .{ + .func = start, + .context = context, + .reset_event = .{}, + }; + pool.run_queue.prepend(&closure.run_node); + pool.mutex.unlock(); + + pool.cond.signal(); + + return @ptrCast(closure); +} + +pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void { + const thread_pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata)); + const closure: *AsyncClosure = @ptrCast(@alignCast(any_future)); + closure.reset_event.wait(); + const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure); + @memcpy(result, (base + @sizeOf(AsyncClosure))[0..result.len]); + thread_pool.allocator.free(base[0 .. @sizeOf(AsyncClosure) + result.len]); +}