std.Io: implement Select

and finish implementation of HostName.connect
This commit is contained in:
Andrew Kelley 2025-10-15 00:36:02 -07:00
parent 35ce907c06
commit d3c4158a10
5 changed files with 144 additions and 36 deletions

View File

@ -639,7 +639,7 @@ pub const VTable = struct {
/// Copied and then passed to `start`.
context: []const u8,
context_alignment: std.mem.Alignment,
start: *const fn (context: *const anyopaque) void,
start: *const fn (*Group, context: *const anyopaque) void,
) void,
groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
@ -1005,7 +1005,8 @@ pub const Group = struct {
pub fn async(g: *Group, io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void {
const Args = @TypeOf(args);
const TypeErased = struct {
fn start(context: *const anyopaque) void {
fn start(group: *Group, context: *const anyopaque) void {
_ = group;
const args_casted: *const Args = @ptrCast(@alignCast(context));
@call(.auto, function, args_casted.*);
}
@ -1033,6 +1034,85 @@ pub const Group = struct {
}
};
pub fn Select(comptime U: type) type {
return struct {
io: Io,
group: Group,
queue: Queue(U),
outstanding: usize,
const S = @This();
pub const Union = U;
pub const Field = std.meta.FieldEnum(U);
pub fn init(io: Io, buffer: []U) S {
return .{
.io = io,
.queue = .init(buffer),
.group = .init,
.outstanding = 0,
};
}
/// Calls `function` with `args` asynchronously. The resource spawned is
/// owned by the select.
///
/// `function` must have return type matching the `field` field of `Union`.
///
/// `function` *may* be called immediately, before `async` returns.
///
/// After this is called, `wait` or `cancel` must be called before the
/// select is deinitialized.
///
/// Threadsafe.
///
/// Related:
/// * `Io.async`
/// * `Group.async`
pub fn async(
s: *S,
comptime field: Field,
function: anytype,
args: std.meta.ArgsTuple(@TypeOf(function)),
) void {
const Args = @TypeOf(args);
const TypeErased = struct {
fn start(group: *Group, context: *const anyopaque) void {
const args_casted: *const Args = @ptrCast(@alignCast(context));
const unerased_select: *S = @fieldParentPtr("group", group);
const elem = @unionInit(U, @tagName(field), @call(.auto, function, args_casted.*));
unerased_select.queue.putOneUncancelable(unerased_select.io, elem);
}
};
_ = @atomicRmw(usize, &s.outstanding, .Add, 1, .monotonic);
s.io.vtable.groupAsync(s.io.userdata, &s.group, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
}
/// Blocks until another task of the select finishes.
///
/// Asserts there is at least one more `outstanding` task.
///
/// Not threadsafe.
pub fn wait(s: *S) Io.Cancelable!U {
s.outstanding -= 1;
return s.queue.getOne(s.io);
}
/// Equivalent to `wait` but requests cancellation on all remaining
/// tasks owned by the select.
///
/// It is illegal to call `wait` after this.
///
/// Idempotent. Not threadsafe.
pub fn cancel(s: *S) void {
s.outstanding = 0;
s.group.cancel(s.io);
}
};
}
pub const Mutex = struct {
state: State,

View File

@ -458,7 +458,7 @@ const GroupClosure = struct {
group: *Io.Group,
/// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
node: std.SinglyLinkedList.Node,
func: *const fn (context: *anyopaque) void,
func: *const fn (*Io.Group, context: *anyopaque) void,
context_alignment: std.mem.Alignment,
context_len: usize,
@ -476,7 +476,7 @@ const GroupClosure = struct {
return;
}
current_closure = closure;
gc.func(gc.contextPointer());
gc.func(group, gc.contextPointer());
current_closure = null;
// In case a cancel happens after successful task completion, prevents
@ -512,7 +512,7 @@ fn groupAsync(
group: *Io.Group,
context: []const u8,
context_alignment: std.mem.Alignment,
start: *const fn (context: *const anyopaque) void,
start: *const fn (*Io.Group, context: *const anyopaque) void,
) void {
if (builtin.single_threaded) return start(context.ptr);
const pool: *Pool = @ptrCast(@alignCast(userdata));
@ -520,7 +520,7 @@ fn groupAsync(
const gpa = pool.allocator;
const n = GroupClosure.contextEnd(context_alignment, context.len);
const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
return start(context.ptr);
return start(group, context.ptr);
}));
gc.* = .{
.closure = .{
@ -548,7 +548,7 @@ fn groupAsync(
pool.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
pool.mutex.unlock();
gc.free(gpa);
return start(context.ptr);
return start(group, context.ptr);
};
pool.run_queue.prepend(&gc.closure.node);
@ -558,7 +558,7 @@ fn groupAsync(
assert(pool.run_queue.popFirst() == &gc.closure.node);
pool.mutex.unlock();
gc.free(gpa);
return start(context.ptr);
return start(group, context.ptr);
};
pool.threads.appendAssumeCapacity(thread);
}
@ -2662,6 +2662,7 @@ fn netLookupFallible(
.{ .address = addr },
.{ .canonical_name = copyCanon(options.canonical_name_buffer, name) },
});
return;
} else |_| {}
}

View File

@ -315,8 +315,8 @@ pub const IpAddress = union(enum) {
};
/// Initiates a connection-oriented network stream.
pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
return io.vtable.netConnectIp(io.userdata, address, options);
pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
return io.vtable.netConnectIp(io.userdata, &address, options);
}
};

View File

@ -88,7 +88,7 @@ pub const LookupResult = union(enum) {
/// Adds any number of `IpAddress` into resolved, exactly one canonical_name,
/// and then always finishes by adding one `LookupResult.end` entry.
///
/// Guaranteed not to block if provided queue has capacity at least 8.
/// Guaranteed not to block if provided queue has capacity at least 16.
pub fn lookup(
host_name: HostName,
io: Io,
@ -216,11 +216,13 @@ pub fn connect(
} });
defer lookup_task.cancel(io);
var select: Io.Select(union(enum) { ip_connect: IpAddress.ConnectError!Stream }) = .init;
defer select.cancel(io);
const Result = union(enum) { connect_result: IpAddress.ConnectError!Stream };
var finished_task_buffer: [results_buffer.len]Result = undefined;
var select: Io.Select(Result) = .init(io, &finished_task_buffer);
defer select.cancel();
while (results.getOne(io)) |result| switch (result) {
.address => |address| select.async(io, .ip_connect, IpAddress.connect, .{ address, io, options }),
.address => |address| select.async(.connect_result, IpAddress.connect, .{ address, io, options }),
.canonical_name => continue,
.end => |lookup_result| {
try lookup_result;
@ -230,8 +232,8 @@ pub fn connect(
var aggregate_error: ConnectError = error.UnknownHostName;
while (select.remaining != 0) switch (select.wait(io)) {
.ip_connect => |ip_connect| if (ip_connect) |stream| return stream else |err| switch (err) {
while (select.outstanding != 0) switch (try select.wait()) {
.connect_result => |connect_result| if (connect_result) |stream| return stream else |err| switch (err) {
error.SystemResources => |e| return e,
error.OptionUnsupported => |e| return e,
error.ProcessFdQuotaExceeded => |e| return e,

View File

@ -1,5 +1,7 @@
const std = @import("std");
const builtin = @import("builtin");
const std = @import("std");
const Io = std.Io;
const net = std.Io.net;
const mem = std.mem;
const testing = std.testing;
@ -126,33 +128,56 @@ test "resolve DNS" {
const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80);
const localhost_v6 = try net.IpAddress.parse("::2", 80);
var addresses_buffer: [8]net.IpAddress = undefined;
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
const result = try net.HostName.lookup(try .init("localhost"), io, .{
var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
var results_buffer: [32]net.HostName.LookupResult = undefined;
var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
net.HostName.lookup(try .init("localhost"), io, &results, .{
.port = 80,
.addresses_buffer = &addresses_buffer,
.canonical_name_buffer = &canon_name_buffer,
.canonical_name_buffer = &canonical_name_buffer,
});
for (addresses_buffer[0..result.addresses_len]) |addr| {
if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break;
} else @panic("unexpected address for localhost");
var addresses_found: usize = 0;
while (results.getOne(io)) |result| switch (result) {
.address => |address| {
if (address.eql(&localhost_v4) or address.eql(&localhost_v6))
addresses_found += 1;
},
.canonical_name => |canonical_name| try testing.expectEqualStrings("localhost", canonical_name.bytes),
.end => |end| {
try end;
break;
},
} else |err| return err;
try testing.expect(addresses_found != 0);
}
{
// The tests are required to work even when there is no Internet connection,
// so some of these errors we must accept and skip the test.
var addresses_buffer: [8]net.IpAddress = undefined;
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
const result = net.HostName.lookup(try .init("example.com"), io, .{
var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
var results_buffer: [16]net.HostName.LookupResult = undefined;
var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
net.HostName.lookup(try .init("example.com"), io, &results, .{
.port = 80,
.addresses_buffer = &addresses_buffer,
.canonical_name_buffer = &canon_name_buffer,
}) catch |err| switch (err) {
error.UnknownHostName => return error.SkipZigTest,
error.NameServerFailure => return error.SkipZigTest,
else => return err,
};
_ = result;
.canonical_name_buffer = &canonical_name_buffer,
});
while (results.getOne(io)) |result| switch (result) {
.address => {},
.canonical_name => {},
.end => |end| {
end catch |err| switch (err) {
error.UnknownHostName => return error.SkipZigTest,
error.NameServerFailure => return error.SkipZigTest,
else => return err,
};
break;
},
} else |err| return err;
}
}