diff --git a/lib/std/Io.zig b/lib/std/Io.zig index 41aebcb712..6da8187cd1 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -639,7 +639,7 @@ pub const VTable = struct { /// Copied and then passed to `start`. context: []const u8, context_alignment: std.mem.Alignment, - start: *const fn (context: *const anyopaque) void, + start: *const fn (*Group, context: *const anyopaque) void, ) void, groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void, groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void, @@ -1005,7 +1005,8 @@ pub const Group = struct { pub fn async(g: *Group, io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void { const Args = @TypeOf(args); const TypeErased = struct { - fn start(context: *const anyopaque) void { + fn start(group: *Group, context: *const anyopaque) void { + _ = group; const args_casted: *const Args = @ptrCast(@alignCast(context)); @call(.auto, function, args_casted.*); } @@ -1033,6 +1034,85 @@ pub const Group = struct { } }; +pub fn Select(comptime U: type) type { + return struct { + io: Io, + group: Group, + queue: Queue(U), + outstanding: usize, + + const S = @This(); + + pub const Union = U; + + pub const Field = std.meta.FieldEnum(U); + + pub fn init(io: Io, buffer: []U) S { + return .{ + .io = io, + .queue = .init(buffer), + .group = .init, + .outstanding = 0, + }; + } + + /// Calls `function` with `args` asynchronously. The resource spawned is + /// owned by the select. + /// + /// `function` must have return type matching the `field` field of `Union`. + /// + /// `function` *may* be called immediately, before `async` returns. + /// + /// After this is called, `wait` or `cancel` must be called before the + /// select is deinitialized. + /// + /// Threadsafe. + /// + /// Related: + /// * `Io.async` + /// * `Group.async` + pub fn async( + s: *S, + comptime field: Field, + function: anytype, + args: std.meta.ArgsTuple(@TypeOf(function)), + ) void { + const Args = @TypeOf(args); + const TypeErased = struct { + fn start(group: *Group, context: *const anyopaque) void { + const args_casted: *const Args = @ptrCast(@alignCast(context)); + const unerased_select: *S = @fieldParentPtr("group", group); + const elem = @unionInit(U, @tagName(field), @call(.auto, function, args_casted.*)); + unerased_select.queue.putOneUncancelable(unerased_select.io, elem); + } + }; + _ = @atomicRmw(usize, &s.outstanding, .Add, 1, .monotonic); + s.io.vtable.groupAsync(s.io.userdata, &s.group, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start); + } + + /// Blocks until another task of the select finishes. + /// + /// Asserts there is at least one more `outstanding` task. + /// + /// Not threadsafe. + pub fn wait(s: *S) Io.Cancelable!U { + s.outstanding -= 1; + return s.queue.getOne(s.io); + } + + /// Equivalent to `wait` but requests cancellation on all remaining + /// tasks owned by the select. + /// + /// It is illegal to call `wait` after this. + /// + /// Idempotent. Not threadsafe. + pub fn cancel(s: *S) void { + s.outstanding = 0; + s.group.cancel(s.io); + } + }; +} + pub const Mutex = struct { state: State, diff --git a/lib/std/Io/Threaded.zig b/lib/std/Io/Threaded.zig index f0d7f3ea4b..f654687684 100644 --- a/lib/std/Io/Threaded.zig +++ b/lib/std/Io/Threaded.zig @@ -458,7 +458,7 @@ const GroupClosure = struct { group: *Io.Group, /// Points to sibling `GroupClosure`. Used for walking the group to cancel all. node: std.SinglyLinkedList.Node, - func: *const fn (context: *anyopaque) void, + func: *const fn (*Io.Group, context: *anyopaque) void, context_alignment: std.mem.Alignment, context_len: usize, @@ -476,7 +476,7 @@ const GroupClosure = struct { return; } current_closure = closure; - gc.func(gc.contextPointer()); + gc.func(group, gc.contextPointer()); current_closure = null; // In case a cancel happens after successful task completion, prevents @@ -512,7 +512,7 @@ fn groupAsync( group: *Io.Group, context: []const u8, context_alignment: std.mem.Alignment, - start: *const fn (context: *const anyopaque) void, + start: *const fn (*Io.Group, context: *const anyopaque) void, ) void { if (builtin.single_threaded) return start(context.ptr); const pool: *Pool = @ptrCast(@alignCast(userdata)); @@ -520,7 +520,7 @@ fn groupAsync( const gpa = pool.allocator; const n = GroupClosure.contextEnd(context_alignment, context.len); const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch { - return start(context.ptr); + return start(group, context.ptr); })); gc.* = .{ .closure = .{ @@ -548,7 +548,7 @@ fn groupAsync( pool.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch { pool.mutex.unlock(); gc.free(gpa); - return start(context.ptr); + return start(group, context.ptr); }; pool.run_queue.prepend(&gc.closure.node); @@ -558,7 +558,7 @@ fn groupAsync( assert(pool.run_queue.popFirst() == &gc.closure.node); pool.mutex.unlock(); gc.free(gpa); - return start(context.ptr); + return start(group, context.ptr); }; pool.threads.appendAssumeCapacity(thread); } @@ -2662,6 +2662,7 @@ fn netLookupFallible( .{ .address = addr }, .{ .canonical_name = copyCanon(options.canonical_name_buffer, name) }, }); + return; } else |_| {} } diff --git a/lib/std/Io/net.zig b/lib/std/Io/net.zig index 53cfee60c5..7ffc6098a2 100644 --- a/lib/std/Io/net.zig +++ b/lib/std/Io/net.zig @@ -315,8 +315,8 @@ pub const IpAddress = union(enum) { }; /// Initiates a connection-oriented network stream. - pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream { - return io.vtable.netConnectIp(io.userdata, address, options); + pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream { + return io.vtable.netConnectIp(io.userdata, &address, options); } }; diff --git a/lib/std/Io/net/HostName.zig b/lib/std/Io/net/HostName.zig index 788c40dd0c..291b2745f4 100644 --- a/lib/std/Io/net/HostName.zig +++ b/lib/std/Io/net/HostName.zig @@ -88,7 +88,7 @@ pub const LookupResult = union(enum) { /// Adds any number of `IpAddress` into resolved, exactly one canonical_name, /// and then always finishes by adding one `LookupResult.end` entry. /// -/// Guaranteed not to block if provided queue has capacity at least 8. +/// Guaranteed not to block if provided queue has capacity at least 16. pub fn lookup( host_name: HostName, io: Io, @@ -216,11 +216,13 @@ pub fn connect( } }); defer lookup_task.cancel(io); - var select: Io.Select(union(enum) { ip_connect: IpAddress.ConnectError!Stream }) = .init; - defer select.cancel(io); + const Result = union(enum) { connect_result: IpAddress.ConnectError!Stream }; + var finished_task_buffer: [results_buffer.len]Result = undefined; + var select: Io.Select(Result) = .init(io, &finished_task_buffer); + defer select.cancel(); while (results.getOne(io)) |result| switch (result) { - .address => |address| select.async(io, .ip_connect, IpAddress.connect, .{ address, io, options }), + .address => |address| select.async(.connect_result, IpAddress.connect, .{ address, io, options }), .canonical_name => continue, .end => |lookup_result| { try lookup_result; @@ -230,8 +232,8 @@ pub fn connect( var aggregate_error: ConnectError = error.UnknownHostName; - while (select.remaining != 0) switch (select.wait(io)) { - .ip_connect => |ip_connect| if (ip_connect) |stream| return stream else |err| switch (err) { + while (select.outstanding != 0) switch (try select.wait()) { + .connect_result => |connect_result| if (connect_result) |stream| return stream else |err| switch (err) { error.SystemResources => |e| return e, error.OptionUnsupported => |e| return e, error.ProcessFdQuotaExceeded => |e| return e, diff --git a/lib/std/Io/net/test.zig b/lib/std/Io/net/test.zig index edac076a6a..1c8f8bd8c7 100644 --- a/lib/std/Io/net/test.zig +++ b/lib/std/Io/net/test.zig @@ -1,5 +1,7 @@ -const std = @import("std"); const builtin = @import("builtin"); + +const std = @import("std"); +const Io = std.Io; const net = std.Io.net; const mem = std.mem; const testing = std.testing; @@ -126,33 +128,56 @@ test "resolve DNS" { const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80); const localhost_v6 = try net.IpAddress.parse("::2", 80); - var addresses_buffer: [8]net.IpAddress = undefined; - var canon_name_buffer: [net.HostName.max_len]u8 = undefined; - const result = try net.HostName.lookup(try .init("localhost"), io, .{ + var canonical_name_buffer: [net.HostName.max_len]u8 = undefined; + var results_buffer: [32]net.HostName.LookupResult = undefined; + var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer); + + net.HostName.lookup(try .init("localhost"), io, &results, .{ .port = 80, - .addresses_buffer = &addresses_buffer, - .canonical_name_buffer = &canon_name_buffer, + .canonical_name_buffer = &canonical_name_buffer, }); - for (addresses_buffer[0..result.addresses_len]) |addr| { - if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break; - } else @panic("unexpected address for localhost"); + + var addresses_found: usize = 0; + + while (results.getOne(io)) |result| switch (result) { + .address => |address| { + if (address.eql(&localhost_v4) or address.eql(&localhost_v6)) + addresses_found += 1; + }, + .canonical_name => |canonical_name| try testing.expectEqualStrings("localhost", canonical_name.bytes), + .end => |end| { + try end; + break; + }, + } else |err| return err; + + try testing.expect(addresses_found != 0); } { // The tests are required to work even when there is no Internet connection, // so some of these errors we must accept and skip the test. - var addresses_buffer: [8]net.IpAddress = undefined; - var canon_name_buffer: [net.HostName.max_len]u8 = undefined; - const result = net.HostName.lookup(try .init("example.com"), io, .{ + var canonical_name_buffer: [net.HostName.max_len]u8 = undefined; + var results_buffer: [16]net.HostName.LookupResult = undefined; + var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer); + + net.HostName.lookup(try .init("example.com"), io, &results, .{ .port = 80, - .addresses_buffer = &addresses_buffer, - .canonical_name_buffer = &canon_name_buffer, - }) catch |err| switch (err) { - error.UnknownHostName => return error.SkipZigTest, - error.NameServerFailure => return error.SkipZigTest, - else => return err, - }; - _ = result; + .canonical_name_buffer = &canonical_name_buffer, + }); + + while (results.getOne(io)) |result| switch (result) { + .address => {}, + .canonical_name => {}, + .end => |end| { + end catch |err| switch (err) { + error.UnknownHostName => return error.SkipZigTest, + error.NameServerFailure => return error.SkipZigTest, + else => return err, + }; + break; + }, + } else |err| return err; } }