mirror of
https://github.com/ziglang/zig.git
synced 2026-02-14 21:38:33 +00:00
std.Io: implement Select
and finish implementation of HostName.connect
This commit is contained in:
parent
35ce907c06
commit
d3c4158a10
@ -639,7 +639,7 @@ pub const VTable = struct {
|
||||
/// Copied and then passed to `start`.
|
||||
context: []const u8,
|
||||
context_alignment: std.mem.Alignment,
|
||||
start: *const fn (context: *const anyopaque) void,
|
||||
start: *const fn (*Group, context: *const anyopaque) void,
|
||||
) void,
|
||||
groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
|
||||
groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
|
||||
@ -1005,7 +1005,8 @@ pub const Group = struct {
|
||||
pub fn async(g: *Group, io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void {
|
||||
const Args = @TypeOf(args);
|
||||
const TypeErased = struct {
|
||||
fn start(context: *const anyopaque) void {
|
||||
fn start(group: *Group, context: *const anyopaque) void {
|
||||
_ = group;
|
||||
const args_casted: *const Args = @ptrCast(@alignCast(context));
|
||||
@call(.auto, function, args_casted.*);
|
||||
}
|
||||
@ -1033,6 +1034,85 @@ pub const Group = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub fn Select(comptime U: type) type {
|
||||
return struct {
|
||||
io: Io,
|
||||
group: Group,
|
||||
queue: Queue(U),
|
||||
outstanding: usize,
|
||||
|
||||
const S = @This();
|
||||
|
||||
pub const Union = U;
|
||||
|
||||
pub const Field = std.meta.FieldEnum(U);
|
||||
|
||||
pub fn init(io: Io, buffer: []U) S {
|
||||
return .{
|
||||
.io = io,
|
||||
.queue = .init(buffer),
|
||||
.group = .init,
|
||||
.outstanding = 0,
|
||||
};
|
||||
}
|
||||
|
||||
/// Calls `function` with `args` asynchronously. The resource spawned is
|
||||
/// owned by the select.
|
||||
///
|
||||
/// `function` must have return type matching the `field` field of `Union`.
|
||||
///
|
||||
/// `function` *may* be called immediately, before `async` returns.
|
||||
///
|
||||
/// After this is called, `wait` or `cancel` must be called before the
|
||||
/// select is deinitialized.
|
||||
///
|
||||
/// Threadsafe.
|
||||
///
|
||||
/// Related:
|
||||
/// * `Io.async`
|
||||
/// * `Group.async`
|
||||
pub fn async(
|
||||
s: *S,
|
||||
comptime field: Field,
|
||||
function: anytype,
|
||||
args: std.meta.ArgsTuple(@TypeOf(function)),
|
||||
) void {
|
||||
const Args = @TypeOf(args);
|
||||
const TypeErased = struct {
|
||||
fn start(group: *Group, context: *const anyopaque) void {
|
||||
const args_casted: *const Args = @ptrCast(@alignCast(context));
|
||||
const unerased_select: *S = @fieldParentPtr("group", group);
|
||||
const elem = @unionInit(U, @tagName(field), @call(.auto, function, args_casted.*));
|
||||
unerased_select.queue.putOneUncancelable(unerased_select.io, elem);
|
||||
}
|
||||
};
|
||||
_ = @atomicRmw(usize, &s.outstanding, .Add, 1, .monotonic);
|
||||
s.io.vtable.groupAsync(s.io.userdata, &s.group, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
|
||||
}
|
||||
|
||||
/// Blocks until another task of the select finishes.
|
||||
///
|
||||
/// Asserts there is at least one more `outstanding` task.
|
||||
///
|
||||
/// Not threadsafe.
|
||||
pub fn wait(s: *S) Io.Cancelable!U {
|
||||
s.outstanding -= 1;
|
||||
return s.queue.getOne(s.io);
|
||||
}
|
||||
|
||||
/// Equivalent to `wait` but requests cancellation on all remaining
|
||||
/// tasks owned by the select.
|
||||
///
|
||||
/// It is illegal to call `wait` after this.
|
||||
///
|
||||
/// Idempotent. Not threadsafe.
|
||||
pub fn cancel(s: *S) void {
|
||||
s.outstanding = 0;
|
||||
s.group.cancel(s.io);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub const Mutex = struct {
|
||||
state: State,
|
||||
|
||||
|
||||
@ -458,7 +458,7 @@ const GroupClosure = struct {
|
||||
group: *Io.Group,
|
||||
/// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
|
||||
node: std.SinglyLinkedList.Node,
|
||||
func: *const fn (context: *anyopaque) void,
|
||||
func: *const fn (*Io.Group, context: *anyopaque) void,
|
||||
context_alignment: std.mem.Alignment,
|
||||
context_len: usize,
|
||||
|
||||
@ -476,7 +476,7 @@ const GroupClosure = struct {
|
||||
return;
|
||||
}
|
||||
current_closure = closure;
|
||||
gc.func(gc.contextPointer());
|
||||
gc.func(group, gc.contextPointer());
|
||||
current_closure = null;
|
||||
|
||||
// In case a cancel happens after successful task completion, prevents
|
||||
@ -512,7 +512,7 @@ fn groupAsync(
|
||||
group: *Io.Group,
|
||||
context: []const u8,
|
||||
context_alignment: std.mem.Alignment,
|
||||
start: *const fn (context: *const anyopaque) void,
|
||||
start: *const fn (*Io.Group, context: *const anyopaque) void,
|
||||
) void {
|
||||
if (builtin.single_threaded) return start(context.ptr);
|
||||
const pool: *Pool = @ptrCast(@alignCast(userdata));
|
||||
@ -520,7 +520,7 @@ fn groupAsync(
|
||||
const gpa = pool.allocator;
|
||||
const n = GroupClosure.contextEnd(context_alignment, context.len);
|
||||
const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
|
||||
return start(context.ptr);
|
||||
return start(group, context.ptr);
|
||||
}));
|
||||
gc.* = .{
|
||||
.closure = .{
|
||||
@ -548,7 +548,7 @@ fn groupAsync(
|
||||
pool.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
|
||||
pool.mutex.unlock();
|
||||
gc.free(gpa);
|
||||
return start(context.ptr);
|
||||
return start(group, context.ptr);
|
||||
};
|
||||
|
||||
pool.run_queue.prepend(&gc.closure.node);
|
||||
@ -558,7 +558,7 @@ fn groupAsync(
|
||||
assert(pool.run_queue.popFirst() == &gc.closure.node);
|
||||
pool.mutex.unlock();
|
||||
gc.free(gpa);
|
||||
return start(context.ptr);
|
||||
return start(group, context.ptr);
|
||||
};
|
||||
pool.threads.appendAssumeCapacity(thread);
|
||||
}
|
||||
@ -2662,6 +2662,7 @@ fn netLookupFallible(
|
||||
.{ .address = addr },
|
||||
.{ .canonical_name = copyCanon(options.canonical_name_buffer, name) },
|
||||
});
|
||||
return;
|
||||
} else |_| {}
|
||||
}
|
||||
|
||||
|
||||
@ -315,8 +315,8 @@ pub const IpAddress = union(enum) {
|
||||
};
|
||||
|
||||
/// Initiates a connection-oriented network stream.
|
||||
pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
|
||||
return io.vtable.netConnectIp(io.userdata, address, options);
|
||||
pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
|
||||
return io.vtable.netConnectIp(io.userdata, &address, options);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ pub const LookupResult = union(enum) {
|
||||
/// Adds any number of `IpAddress` into resolved, exactly one canonical_name,
|
||||
/// and then always finishes by adding one `LookupResult.end` entry.
|
||||
///
|
||||
/// Guaranteed not to block if provided queue has capacity at least 8.
|
||||
/// Guaranteed not to block if provided queue has capacity at least 16.
|
||||
pub fn lookup(
|
||||
host_name: HostName,
|
||||
io: Io,
|
||||
@ -216,11 +216,13 @@ pub fn connect(
|
||||
} });
|
||||
defer lookup_task.cancel(io);
|
||||
|
||||
var select: Io.Select(union(enum) { ip_connect: IpAddress.ConnectError!Stream }) = .init;
|
||||
defer select.cancel(io);
|
||||
const Result = union(enum) { connect_result: IpAddress.ConnectError!Stream };
|
||||
var finished_task_buffer: [results_buffer.len]Result = undefined;
|
||||
var select: Io.Select(Result) = .init(io, &finished_task_buffer);
|
||||
defer select.cancel();
|
||||
|
||||
while (results.getOne(io)) |result| switch (result) {
|
||||
.address => |address| select.async(io, .ip_connect, IpAddress.connect, .{ address, io, options }),
|
||||
.address => |address| select.async(.connect_result, IpAddress.connect, .{ address, io, options }),
|
||||
.canonical_name => continue,
|
||||
.end => |lookup_result| {
|
||||
try lookup_result;
|
||||
@ -230,8 +232,8 @@ pub fn connect(
|
||||
|
||||
var aggregate_error: ConnectError = error.UnknownHostName;
|
||||
|
||||
while (select.remaining != 0) switch (select.wait(io)) {
|
||||
.ip_connect => |ip_connect| if (ip_connect) |stream| return stream else |err| switch (err) {
|
||||
while (select.outstanding != 0) switch (try select.wait()) {
|
||||
.connect_result => |connect_result| if (connect_result) |stream| return stream else |err| switch (err) {
|
||||
error.SystemResources => |e| return e,
|
||||
error.OptionUnsupported => |e| return e,
|
||||
error.ProcessFdQuotaExceeded => |e| return e,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const std = @import("std");
|
||||
const Io = std.Io;
|
||||
const net = std.Io.net;
|
||||
const mem = std.mem;
|
||||
const testing = std.testing;
|
||||
@ -126,33 +128,56 @@ test "resolve DNS" {
|
||||
const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80);
|
||||
const localhost_v6 = try net.IpAddress.parse("::2", 80);
|
||||
|
||||
var addresses_buffer: [8]net.IpAddress = undefined;
|
||||
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
|
||||
const result = try net.HostName.lookup(try .init("localhost"), io, .{
|
||||
var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
|
||||
var results_buffer: [32]net.HostName.LookupResult = undefined;
|
||||
var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
|
||||
|
||||
net.HostName.lookup(try .init("localhost"), io, &results, .{
|
||||
.port = 80,
|
||||
.addresses_buffer = &addresses_buffer,
|
||||
.canonical_name_buffer = &canon_name_buffer,
|
||||
.canonical_name_buffer = &canonical_name_buffer,
|
||||
});
|
||||
for (addresses_buffer[0..result.addresses_len]) |addr| {
|
||||
if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break;
|
||||
} else @panic("unexpected address for localhost");
|
||||
|
||||
var addresses_found: usize = 0;
|
||||
|
||||
while (results.getOne(io)) |result| switch (result) {
|
||||
.address => |address| {
|
||||
if (address.eql(&localhost_v4) or address.eql(&localhost_v6))
|
||||
addresses_found += 1;
|
||||
},
|
||||
.canonical_name => |canonical_name| try testing.expectEqualStrings("localhost", canonical_name.bytes),
|
||||
.end => |end| {
|
||||
try end;
|
||||
break;
|
||||
},
|
||||
} else |err| return err;
|
||||
|
||||
try testing.expect(addresses_found != 0);
|
||||
}
|
||||
|
||||
{
|
||||
// The tests are required to work even when there is no Internet connection,
|
||||
// so some of these errors we must accept and skip the test.
|
||||
var addresses_buffer: [8]net.IpAddress = undefined;
|
||||
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
|
||||
const result = net.HostName.lookup(try .init("example.com"), io, .{
|
||||
var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
|
||||
var results_buffer: [16]net.HostName.LookupResult = undefined;
|
||||
var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
|
||||
|
||||
net.HostName.lookup(try .init("example.com"), io, &results, .{
|
||||
.port = 80,
|
||||
.addresses_buffer = &addresses_buffer,
|
||||
.canonical_name_buffer = &canon_name_buffer,
|
||||
}) catch |err| switch (err) {
|
||||
error.UnknownHostName => return error.SkipZigTest,
|
||||
error.NameServerFailure => return error.SkipZigTest,
|
||||
else => return err,
|
||||
};
|
||||
_ = result;
|
||||
.canonical_name_buffer = &canonical_name_buffer,
|
||||
});
|
||||
|
||||
while (results.getOne(io)) |result| switch (result) {
|
||||
.address => {},
|
||||
.canonical_name => {},
|
||||
.end => |end| {
|
||||
end catch |err| switch (err) {
|
||||
error.UnknownHostName => return error.SkipZigTest,
|
||||
error.NameServerFailure => return error.SkipZigTest,
|
||||
else => return err,
|
||||
};
|
||||
break;
|
||||
},
|
||||
} else |err| return err;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user