WIP: hack away at std.Io return flight

This commit is contained in:
Andrew Kelley 2025-10-05 20:27:09 -07:00
parent 774df26835
commit b428612a20
10 changed files with 341 additions and 252 deletions

View File

@ -719,31 +719,38 @@ pub const Timestamp = struct {
///
/// The epoch is implementation-defined. For example NTFS/Windows uses
/// 1601-01-01.
realtime,
real,
/// A nonsettable system-wide clock that represents time since some
/// unspecified point in the past.
///
/// On Linux, corresponds to how long the system has been running since
/// it booted.
///
/// Not affected by discontinuous jumps in the system time (e.g., if
/// the system administrator manually changes the clock), but is
/// affected by frequency adjustments. **This clock does not count time
/// that the system is suspended.**
///
/// Guarantees that the time returned by consecutive calls will not go
/// backwards, but successive calls may return identical
/// Monotonic: Guarantees that the time returned by consecutive calls
/// will not go backwards, but successive calls may return identical
/// (not-increased) time values.
///
/// May or may not include time the system is suspended, but
/// implementations should exclude that time if possible.
monotonic,
/// Identical to `monotonic` except it also includes any time that the
/// system is suspended, if possible. However, it may be implemented
/// identically to `monotonic`.
boottime,
process_cputime_id,
thread_cputime_id,
/// Not affected by discontinuous jumps in the system time (e.g., if
/// the system administrator manually changes the clock), but may be
/// affected by frequency adjustments.
///
/// This clock expresses intent to **exclude time that the system is
/// suspended**. However, implementations may be unable to satisify
/// this, and may include that time.
///
/// * On Linux, corresponds `CLOCK_MONOTONIC`.
/// * On macOS, corresponds to `CLOCK_UPTIME_RAW`.
awake,
/// Identical to `awake` except it expresses intent to include time
/// that the system is suspended, however, it may be implemented
/// identically to `awake`.
///
/// * On Linux, corresponds `CLOCK_BOOTTIME`.
/// * On macOS, corresponds to `CLOCK_MONOTONIC_RAW`.
boot,
/// Tracks the amount of CPU in user or kernel mode used by the calling
/// process.
cpu_process,
/// Tracks the amount of CPU in user or kernel mode used by the calling
/// thread.
cpu_thread,
};
pub fn durationTo(from: Timestamp, to: Timestamp) Duration {
@ -825,7 +832,7 @@ pub const Duration = struct {
}
pub fn sleep(duration: Duration, io: Io) SleepError!void {
return io.vtable.sleep(io.userdata, .{ .duration = .{ .duration = duration, .clock = .monotonic } });
return io.vtable.sleep(io.userdata, .{ .duration = .{ .duration = duration, .clock = .awake } });
}
};

View File

@ -319,6 +319,11 @@ pub const Reader = struct {
};
}
/// Takes a legacy `std.fs.File` to help with upgrading.
pub fn initAdapted(file: std.fs.File, io: Io, buffer: []u8) Reader {
return .init(.{ .handle = file.handle }, io, buffer);
}
pub fn initSize(file: File, io: Io, buffer: []u8, size: ?u64) Reader {
return .{
.io = io,

View File

@ -1032,7 +1032,7 @@ fn nowWindows(userdata: ?*anyopaque, clock: Io.Timestamp.Clock) Io.Timestamp.Err
// and uses the NTFS/Windows epoch, which is 1601-01-01.
return @as(i96, windows.ntdll.RtlGetSystemTimePrecise()) * 100;
},
.monotonic, .boottime => {
.monotonic, .uptime => {
// QPC on windows doesn't fail on >= XP/2000 and includes time suspended.
return .{ .timestamp = windows.QueryPerformanceCounter() };
},
@ -1132,7 +1132,8 @@ fn sleepPosix(userdata: ?*anyopaque, timeout: Io.Timeout) Io.SleepError!void {
.sec = std.math.maxInt(sec_type),
.nsec = std.math.maxInt(nsec_type),
};
if (d.clock != .monotonic) return error.UnsupportedClock;
// TODO check which clock nanosleep uses on this host
// and return error.UnsupportedClock if it does not match
const ns = d.duration.nanoseconds;
break :t .{
.sec = @intCast(@divFloor(ns, std.time.ns_per_s)),
@ -1331,11 +1332,15 @@ fn setSocketOption(pool: *Pool, fd: posix.fd_t, level: i32, opt_name: u32, optio
fn ipConnectPosix(
userdata: ?*anyopaque,
address: *const Io.net.IpAddress,
options: Io.net.IpAddress.BindOptions,
options: Io.net.IpAddress.ConnectOptions,
) Io.net.IpAddress.ConnectError!Io.net.Stream {
if (options.timeout != .none) @panic("TODO");
const pool: *Pool = @ptrCast(@alignCast(userdata));
const family = posixAddressFamily(address);
const socket_fd = try openSocketPosix(pool, family, options);
const socket_fd = try openSocketPosix(pool, family, .{
.mode = options.mode,
.protocol = options.protocol,
});
var storage: PosixAddress = undefined;
var addr_len = addressToPosix(address, &storage);
try posixConnect(pool, socket_fd, &storage.any, addr_len);
@ -1490,11 +1495,11 @@ fn netSend(
const pool: *Pool = @ptrCast(@alignCast(userdata));
const posix_flags: u32 =
@as(u32, if (flags.confirm) posix.MSG.CONFIRM else 0) |
@as(u32, if (@hasDecl(posix.MSG, "CONFIRM") and flags.confirm) posix.MSG.CONFIRM else 0) |
@as(u32, if (flags.dont_route) posix.MSG.DONTROUTE else 0) |
@as(u32, if (flags.eor) posix.MSG.EOR else 0) |
@as(u32, if (flags.oob) posix.MSG.OOB else 0) |
@as(u32, if (flags.fastopen) posix.MSG.FASTOPEN else 0) |
@as(u32, if (@hasDecl(posix.MSG, "FASTOPEN") and flags.fastopen) posix.MSG.FASTOPEN else 0) |
posix.MSG.NOSIGNAL;
var i: usize = 0;
@ -2024,11 +2029,17 @@ fn recoverableOsBugDetected() void {
fn clockToPosix(clock: Io.Timestamp.Clock) posix.clockid_t {
return switch (clock) {
.realtime => posix.CLOCK.REALTIME,
.monotonic => posix.CLOCK.MONOTONIC,
.boottime => posix.CLOCK.BOOTTIME,
.process_cputime_id => posix.CLOCK.PROCESS_CPUTIME_ID,
.thread_cputime_id => posix.CLOCK.THREAD_CPUTIME_ID,
.real => posix.CLOCK.REALTIME,
.awake => switch (builtin.os.tag) {
.macos, .ios, .watchos, .tvos => posix.CLOCK.UPTIME_RAW,
else => posix.CLOCK.MONOTONIC,
},
.boot => switch (builtin.os.tag) {
.macos, .ios, .watchos, .tvos => posix.CLOCK.MONOTONIC_RAW,
else => posix.CLOCK.BOOTTIME,
},
.cpu_process => posix.CLOCK.PROCESS_CPUTIME_ID,
.cpu_thread => posix.CLOCK.THREAD_CPUTIME_ID,
};
}
@ -2036,7 +2047,7 @@ fn clockToWasi(clock: Io.Timestamp.Clock) std.os.wasi.clockid_t {
return switch (clock) {
.realtime => .REALTIME,
.monotonic => .MONOTONIC,
.boottime => .MONOTONIC,
.uptime => .MONOTONIC,
.process_cputime_id => .PROCESS_CPUTIME_ID,
.thread_cputime_id => .THREAD_CPUTIME_ID,
};

View File

@ -186,7 +186,7 @@ pub const IpAddress = union(enum) {
/// Waits for a TCP connection. When using this API, `bind` does not need
/// to be called. The returned `Server` has an open `stream`.
pub fn listen(address: IpAddress, io: Io, options: ListenOptions) ListenError!Server {
return io.vtable.tcpListen(io.userdata, address, options);
return io.vtable.listen(io.userdata, address, options);
}
pub const BindError = error{
@ -236,6 +236,8 @@ pub const IpAddress = union(enum) {
AddressInUse,
AddressUnavailable,
AddressFamilyUnsupported,
/// Insufficient memory or other resource internal to the operating system.
SystemResources,
ConnectionPending,
ConnectionRefused,
ConnectionResetByPeer,
@ -246,12 +248,23 @@ pub const IpAddress = union(enum) {
/// One of the `ConnectOptions` is not supported by the Io
/// implementation.
OptionUnsupported,
} || Io.UnexpectedError || Io.Cancelable;
/// Per-process limit on the number of open file descriptors has been reached.
ProcessFdQuotaExceeded,
/// System-wide limit on the total number of open files has been reached.
SystemFdQuotaExceeded,
ProtocolUnsupportedBySystem,
ProtocolUnsupportedByAddressFamily,
SocketModeUnsupported,
} || Io.Timeout.Error || Io.UnexpectedError || Io.Cancelable;
pub const ConnectOptions = BindOptions;
pub const ConnectOptions = struct {
mode: Socket.Mode,
protocol: ?Protocol = null,
timeout: Io.Timeout = .none,
};
/// Initiates a connection-oriented network stream.
pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
return io.vtable.ipConnect(io.userdata, address, options);
}
};
@ -997,7 +1010,7 @@ pub const Stream = struct {
socket: Socket,
pub fn close(s: *Stream, io: Io) void {
io.vtable.netClose(io.userdata, s.socket);
io.vtable.netClose(io.userdata, s.socket.handle);
s.* = undefined;
}
@ -1040,10 +1053,13 @@ pub const Stream = struct {
return n;
}
fn readVec(io_r: *Reader, data: [][]u8) Io.Reader.Error!usize {
fn readVec(io_r: *Io.Reader, data: [][]u8) Io.Reader.Error!usize {
const r: *Reader = @alignCast(@fieldParentPtr("interface", io_r));
const io = r.io;
return io.vtable.netReadVec(io.vtable.userdata, r.stream, io_r, data);
return io.vtable.netRead(io.userdata, r.stream, data) catch |err| {
r.err = err;
return error.ReadFailed;
};
}
};
@ -1078,7 +1094,10 @@ pub const Stream = struct {
const w: *Writer = @alignCast(@fieldParentPtr("interface", io_w));
const io = w.io;
const buffered = io_w.buffered();
const n = try io.vtable.netWrite(io.vtable.userdata, w.stream, buffered, data, splat);
const n = io.vtable.netWrite(io.userdata, w.stream, buffered, data, splat) catch |err| {
w.err = err;
return error.WriteFailed;
};
return io_w.consume(n);
}
};
@ -1104,7 +1123,7 @@ pub const Server = struct {
/// Blocks until a client connects to the server.
pub fn accept(s: *Server, io: Io) AcceptError!Stream {
return io.vtable.accept(io, s);
return io.vtable.accept(io.userdata, s);
}
};

View File

@ -19,12 +19,12 @@ bytes: []const u8,
pub const max_len = 255;
pub const InitError = error{
pub const ValidateError = error{
NameTooLong,
InvalidHostName,
};
pub fn init(bytes: []const u8) InitError!HostName {
pub fn validate(bytes: []const u8) ValidateError!void {
if (bytes.len > max_len) return error.NameTooLong;
if (!std.unicode.utf8ValidateSlice(bytes)) return error.InvalidHostName;
for (bytes) |byte| {
@ -33,10 +33,34 @@ pub fn init(bytes: []const u8) InitError!HostName {
}
return error.InvalidHostName;
}
}
pub fn init(bytes: []const u8) ValidateError!HostName {
try validate(bytes);
return .{ .bytes = bytes };
}
/// TODO add a retry field here
pub fn sameParentDomain(parent_host: HostName, child_host: HostName) bool {
const parent_bytes = parent_host.bytes;
const child_bytes = child_host.bytes;
if (!std.ascii.endsWithIgnoreCase(child_bytes, parent_bytes)) return false;
if (child_bytes.len == parent_bytes.len) return true;
if (parent_bytes.len > child_bytes.len) return false;
return child_bytes[child_bytes.len - parent_bytes.len - 1] == '.';
}
test sameParentDomain {
try std.testing.expect(!sameParentDomain(try .init("foo.com"), try .init("bar.com")));
try std.testing.expect(sameParentDomain(try .init("foo.com"), try .init("foo.com")));
try std.testing.expect(sameParentDomain(try .init("foo.com"), try .init("bar.foo.com")));
try std.testing.expect(!sameParentDomain(try .init("bar.foo.com"), try .init("foo.com")));
}
/// Domain names are case-insensitive (RFC 5890, Section 2.3.2.4)
pub fn eql(a: HostName, b: HostName) bool {
return std.ascii.eqlIgnoreCase(a.bytes, b.bytes);
}
pub const LookupOptions = struct {
port: u16,
/// Must have at least length 2.
@ -266,15 +290,15 @@ fn lookupDns(io: Io, lookup_canon_name: []const u8, rc: *const ResolvConf, optio
var answers_remaining = answers.len;
for (answers) |*answer| answer.len = 0;
// boottime is chosen because time the computer is suspended should count
// boot clock is chosen because time the computer is suspended should count
// against time spent waiting for external messages to arrive.
var now_ts = try Io.Timestamp.now(io, .boottime);
var now_ts = try Io.Timestamp.now(io, .boot);
const final_ts = now_ts.addDuration(.fromSeconds(rc.timeout_seconds));
const attempt_duration: Io.Duration = .{
.nanoseconds = std.time.ns_per_s * @as(usize, rc.timeout_seconds) / rc.attempts,
};
send: while (now_ts.compare(.lt, final_ts)) : (now_ts = try Io.Timestamp.now(io, .boottime)) {
send: while (now_ts.compare(.lt, final_ts)) : (now_ts = try Io.Timestamp.now(io, .boot)) {
const max_messages = queries_buffer.len * ResolvConf.max_nameservers;
{
var message_buffer: [max_messages]Io.net.OutgoingMessage = undefined;
@ -518,7 +542,7 @@ fn writeResolutionQuery(q: *[280]u8, op: u4, dname: []const u8, class: u8, ty: u
return n;
}
pub const ExpandError = error{InvalidDnsPacket} || InitError;
pub const ExpandError = error{InvalidDnsPacket} || ValidateError;
/// Decompresses a DNS name.
///
@ -618,22 +642,36 @@ pub const DnsResponse = struct {
}
};
pub const ConnectTcpError = LookupError || IpAddress.ConnectTcpError;
pub const ConnectError = LookupError || IpAddress.ConnectError;
pub fn connectTcp(host_name: HostName, io: Io, port: u16) ConnectTcpError!Stream {
pub fn connect(
host_name: HostName,
io: Io,
port: u16,
options: IpAddress.ConnectOptions,
) ConnectError!Stream {
var addresses_buffer: [32]IpAddress = undefined;
var canonical_name_buffer: [HostName.max_len]u8 = undefined;
const results = try lookup(host_name, .{
const results = try lookup(host_name, io, .{
.port = port,
.addresses_buffer = &addresses_buffer,
.canonical_name_buffer = &.{},
.canonical_name_buffer = &canonical_name_buffer,
});
const addresses = addresses_buffer[0..results.addresses_len];
if (addresses.len == 0) return error.UnknownHostName;
for (addresses) |addr| {
return addr.connectTcp(io) catch |err| switch (err) {
// TODO instead of serially, use a Select API to send out
// the connections simultaneously and then keep the first
// successful one, canceling the rest.
// TODO On Linux this should additionally use an Io.Queue based
// DNS resolution API in order to send out a connection after
// each DNS response before waiting for the rest of them.
for (addresses) |*addr| {
return addr.connect(io, options) catch |err| switch (err) {
error.ConnectionRefused => continue,
else => |e| return e,
};

View File

@ -7,32 +7,30 @@ const testing = std.testing;
test "parse and render IP addresses at comptime" {
comptime {
const ipv6addr = net.IpAddress.parse("::1", 0) catch unreachable;
try std.testing.expectFmt("[::1]:0", "{f}", .{ipv6addr});
try testing.expectFmt("[::1]:0", "{f}", .{ipv6addr});
const ipv4addr = net.IpAddress.parse("127.0.0.1", 0) catch unreachable;
try std.testing.expectFmt("127.0.0.1:0", "{f}", .{ipv4addr});
try testing.expectFmt("127.0.0.1:0", "{f}", .{ipv4addr});
try testing.expectError(error.ParseFailed, net.IpAddress.parse("::123.123.123.123", 0));
try testing.expectError(error.ParseFailed, net.IpAddress.parse("127.01.0.1", 0));
try testing.expectError(error.ParseFailed, net.IpAddress.resolveIp("::123.123.123.123", 0));
try testing.expectError(error.ParseFailed, net.IpAddress.resolveIp("127.01.0.1", 0));
}
}
test "format IPv6 address with no zero runs" {
const addr = try std.net.IpAddress.parseIp6("2001:db8:1:2:3:4:5:6", 0);
try std.testing.expectFmt("[2001:db8:1:2:3:4:5:6]:0", "{f}", .{addr});
const addr = try net.IpAddress.parseIp6("2001:db8:1:2:3:4:5:6", 0);
try testing.expectFmt("[2001:db8:1:2:3:4:5:6]:0", "{f}", .{addr});
}
test "parse IPv6 addresses and check compressed form" {
try std.testing.expectFmt("[2001:db8::1:0:0:2]:0", "{f}", .{
try std.net.IpAddress.parseIp6("2001:0db8:0000:0000:0001:0000:0000:0002", 0),
try testing.expectFmt("[2001:db8::1:0:0:2]:0", "{f}", .{
try net.IpAddress.parseIp6("2001:0db8:0000:0000:0001:0000:0000:0002", 0),
});
try std.testing.expectFmt("[2001:db8::1:2]:0", "{f}", .{
try std.net.IpAddress.parseIp6("2001:0db8:0000:0000:0000:0000:0001:0002", 0),
try testing.expectFmt("[2001:db8::1:2]:0", "{f}", .{
try net.IpAddress.parseIp6("2001:0db8:0000:0000:0000:0000:0001:0002", 0),
});
try std.testing.expectFmt("[2001:db8:1:0:1::2]:0", "{f}", .{
try std.net.IpAddress.parseIp6("2001:0db8:0001:0000:0001:0000:0000:0002", 0),
try testing.expectFmt("[2001:db8:1:0:1::2]:0", "{f}", .{
try net.IpAddress.parseIp6("2001:0db8:0001:0000:0001:0000:0000:0002", 0),
});
}
@ -43,14 +41,14 @@ test "parse IPv6 address, check raw bytes" {
0x00, 0x01, 0x00, 0x00, // :0001:0000
0x00, 0x00, 0x00, 0x02, // :0000:0002
};
const addr = try std.net.IpAddress.parseIp6("2001:db8:0000:0000:0001:0000:0000:0002", 0);
const actual_raw = addr.in6.sa.addr[0..];
try std.testing.expectEqualSlices(u8, expected_raw[0..], actual_raw);
const addr = try net.IpAddress.parseIp6("2001:db8:0000:0000:0001:0000:0000:0002", 0);
try testing.expectEqualSlices(u8, &expected_raw, &addr.ip6.bytes);
}
test "parse and render IPv6 addresses" {
// TODO make this test parsing and rendering only, then it doesn't need I/O
const io = testing.io;
var buffer: [100]u8 = undefined;
const ips = [_][]const u8{
"FF01:0:0:0:0:0:0:FB",
@ -79,12 +77,12 @@ test "parse and render IPv6 addresses" {
for (ips, 0..) |ip, i| {
const addr = net.IpAddress.parseIp6(ip, 0) catch unreachable;
var newIp = std.fmt.bufPrint(buffer[0..], "{f}", .{addr}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3]));
try testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3]));
if (builtin.os.tag == .linux) {
const addr_via_resolve = net.IpAddress.resolveIp6(ip, 0) catch unreachable;
const addr_via_resolve = net.IpAddress.resolveIp6(io, ip, 0) catch unreachable;
var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{f}", .{addr_via_resolve}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3]));
try testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3]));
}
}
@ -97,21 +95,23 @@ test "parse and render IPv6 addresses" {
try testing.expectError(error.Incomplete, net.IpAddress.parseIp6("1", 0));
// TODO Make this test pass on other operating systems.
if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin() or builtin.os.tag == .windows) {
try testing.expectError(error.Incomplete, net.IpAddress.resolveIp6("ff01::fb%", 0));
try testing.expectError(error.Incomplete, net.IpAddress.resolveIp6(io, "ff01::fb%", 0));
// Assumes IFNAMESIZE will always be a multiple of 2
try testing.expectError(error.Overflow, net.IpAddress.resolveIp6("ff01::fb%wlp3" ++ "s0" ** @divExact(std.posix.IFNAMESIZE - 4, 2), 0));
try testing.expectError(error.Overflow, net.IpAddress.resolveIp6("ff01::fb%12345678901234", 0));
try testing.expectError(error.Overflow, net.IpAddress.resolveIp6(io, "ff01::fb%wlp3" ++ "s0" ** @divExact(std.posix.IFNAMESIZE - 4, 2), 0));
try testing.expectError(error.Overflow, net.IpAddress.resolveIp6(io, "ff01::fb%12345678901234", 0));
}
}
test "invalid but parseable IPv6 scope ids" {
const io = testing.io;
if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin() and builtin.os.tag != .windows) {
// Currently, resolveIp6 with alphanumerical scope IDs only works on Linux.
// TODO Make this test pass on other operating systems.
return error.SkipZigTest;
}
try testing.expectError(error.InterfaceNotFound, net.IpAddress.resolveIp6("ff01::fb%123s45678901234", 0));
try testing.expectError(error.InterfaceNotFound, net.IpAddress.resolveIp6(io, "ff01::fb%123s45678901234", 0));
}
test "parse and render IPv4 addresses" {
@ -125,7 +125,7 @@ test "parse and render IPv4 addresses" {
}) |ip| {
const addr = net.IpAddress.parseIp4(ip, 0) catch unreachable;
var newIp = std.fmt.bufPrint(buffer[0..], "{f}", .{addr}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, ip, newIp[0 .. newIp.len - 2]));
try testing.expect(std.mem.eql(u8, ip, newIp[0 .. newIp.len - 2]));
}
try testing.expectError(error.Overflow, net.IpAddress.parseIp4("256.0.0.1", 0));
@ -136,50 +136,43 @@ test "parse and render IPv4 addresses" {
try testing.expectError(error.NonCanonical, net.IpAddress.parseIp4("127.01.0.1", 0));
}
test "parse and render UNIX addresses" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
if (!net.has_unix_sockets) return error.SkipZigTest;
const addr = net.Address.initUnix("/tmp/testpath") catch unreachable;
try std.testing.expectFmt("/tmp/testpath", "{f}", .{addr});
const too_long = [_]u8{'a'} ** 200;
try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..]));
}
test "resolve DNS" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
if (builtin.os.tag == .windows) {
_ = try std.os.windows.WSAStartup(2, 2);
}
defer {
if (builtin.os.tag == .windows) {
std.os.windows.WSACleanup() catch unreachable;
}
}
const io = testing.io;
// Resolve localhost, this should not fail.
{
const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80);
const localhost_v6 = try net.IpAddress.parse("::2", 80);
const result = try net.getAddressList(testing.allocator, "localhost", 80);
defer result.deinit();
for (result.addrs) |addr| {
if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break;
var addresses_buffer: [8]net.IpAddress = undefined;
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
const result = try net.HostName.lookup(try .init("localhost"), io, .{
.port = 80,
.addresses_buffer = &addresses_buffer,
.canonical_name_buffer = &canon_name_buffer,
});
for (addresses_buffer[0..result.addresses_len]) |addr| {
if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break;
} else @panic("unexpected address for localhost");
}
{
// The tests are required to work even when there is no Internet connection,
// so some of these errors we must accept and skip the test.
const result = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) {
var addresses_buffer: [8]net.IpAddress = undefined;
var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
const result = net.HostName.lookup(try .init("example.com"), io, .{
.port = 80,
.addresses_buffer = &addresses_buffer,
.canonical_name_buffer = &canon_name_buffer,
}) catch |err| switch (err) {
error.UnknownHostName => return error.SkipZigTest,
error.TemporaryNameServerFailure => return error.SkipZigTest,
error.NameServerFailure => return error.SkipZigTest,
else => return err,
};
result.deinit();
_ = result;
}
}
@ -187,6 +180,8 @@ test "listen on a port, send bytes, receive bytes" {
if (builtin.single_threaded) return error.SkipZigTest;
if (builtin.os.tag == .wasi) return error.SkipZigTest;
const io = testing.io;
if (builtin.os.tag == .windows) {
_ = try std.os.windows.WSAStartup(2, 2);
}
@ -198,28 +193,28 @@ test "listen on a port, send bytes, receive bytes" {
// Try only the IPv4 variant as some CI builders have no IPv6 localhost
// configured.
const localhost = try net.IpAddress.parse("127.0.0.1", 0);
const localhost: net.IpAddress = .{ .ip4 = .loopback(0) };
var server = try localhost.listen(.{});
defer server.deinit();
var server = try localhost.listen(io, .{});
defer server.deinit(io);
const S = struct {
fn clientFn(server_address: net.IpAddress) !void {
const socket = try net.tcpConnectToAddress(server_address);
defer socket.close();
var stream = try server_address.connect(io, .{ .mode = .stream });
defer stream.close(io);
var stream_writer = socket.writer(&.{});
var stream_writer = stream.writer(io, &.{});
try stream_writer.interface.writeAll("Hello world!");
}
};
const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address});
const t = try std.Thread.spawn(.{}, S.clientFn, .{server.socket.address});
defer t.join();
var client = try server.accept();
defer client.stream.close();
var client = try server.accept(io);
defer client.stream.close(io);
var buf: [16]u8 = undefined;
var stream_reader = client.stream.reader(&.{});
var stream_reader = client.stream.reader(io, &.{});
const n = try stream_reader.interface().readSliceShort(&buf);
try testing.expectEqual(@as(usize, 12), n);
@ -232,13 +227,15 @@ test "listen on an in use port" {
return error.SkipZigTest;
}
const localhost = try net.IpAddress.parse("127.0.0.1", 0);
const io = testing.io;
var server1 = try localhost.listen(.{ .reuse_address = true });
defer server1.deinit();
const localhost: net.IpAddress = .{ .ip4 = .loopback(0) };
var server2 = try server1.listen_address.listen(.{ .reuse_address = true });
defer server2.deinit();
var server1 = try localhost.listen(io, .{ .reuse_address = true });
defer server1.deinit(io);
var server2 = try server1.socket.address.listen(io, .{ .reuse_address = true });
defer server2.deinit(io);
}
fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void {
@ -268,9 +265,11 @@ fn testClient(addr: net.IpAddress) anyerror!void {
fn testServer(server: *net.Server) anyerror!void {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
var client = try server.accept();
const io = testing.io;
const stream = client.stream.writer();
var client = try server.accept(io);
const stream = client.stream.writer(io);
try stream.print("hello from server\n", .{});
}
@ -278,6 +277,8 @@ test "listen on a unix socket, send bytes, receive bytes" {
if (builtin.single_threaded) return error.SkipZigTest;
if (!net.has_unix_sockets) return error.SkipZigTest;
const io = testing.io;
if (builtin.os.tag == .windows) {
_ = try std.os.windows.WSAStartup(2, 2);
}
@ -293,15 +294,15 @@ test "listen on a unix socket, send bytes, receive bytes" {
const socket_addr = try net.IpAddress.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {};
var server = try socket_addr.listen(.{});
defer server.deinit();
var server = try socket_addr.listen(io, .{});
defer server.deinit(io);
const S = struct {
fn clientFn(path: []const u8) !void {
const socket = try net.connectUnixSocket(path);
defer socket.close();
var stream = try net.connectUnixSocket(path);
defer stream.close(io);
var stream_writer = socket.writer(&.{});
var stream_writer = stream.writer(io, &.{});
try stream_writer.interface.writeAll("Hello world!");
}
};
@ -309,10 +310,10 @@ test "listen on a unix socket, send bytes, receive bytes" {
const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path});
defer t.join();
var client = try server.accept();
defer client.stream.close();
var client = try server.accept(io);
defer client.stream.close(io);
var buf: [16]u8 = undefined;
var stream_reader = client.stream.reader(&.{});
var stream_reader = client.stream.reader(io, &.{});
const n = try stream_reader.interface().readSliceShort(&buf);
try testing.expectEqual(@as(usize, 12), n);
@ -324,14 +325,16 @@ test "listen on a unix socket with reuse_address option" {
// Windows doesn't implement reuse port option.
if (builtin.os.tag == .windows) return error.SkipZigTest;
const io = testing.io;
const socket_path = try generateFileName("socket.unix");
defer testing.allocator.free(socket_path);
const socket_addr = try net.Address.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {};
var server = try socket_addr.listen(.{ .reuse_address = true });
server.deinit();
var server = try socket_addr.listen(io, .{ .reuse_address = true });
server.deinit(io);
}
fn generateFileName(base_name: []const u8) ![]const u8 {
@ -351,19 +354,21 @@ test "non-blocking tcp server" {
return error.SkipZigTest;
}
const localhost = try net.IpAddress.parse("127.0.0.1", 0);
var server = localhost.listen(.{ .force_nonblocking = true });
defer server.deinit();
const io = testing.io;
const accept_err = server.accept();
const localhost: net.IpAddress = .{ .ip4 = .loopback(0) };
var server = localhost.listen(io, .{ .force_nonblocking = true });
defer server.deinit(io);
const accept_err = server.accept(io);
try testing.expectError(error.WouldBlock, accept_err);
const socket_file = try net.tcpConnectToAddress(server.listen_address);
const socket_file = try net.tcpConnectToAddress(server.socket.address);
defer socket_file.close();
var client = try server.accept();
defer client.stream.close();
const stream = client.stream.writer();
var client = try server.accept(io);
defer client.stream.close(io);
const stream = client.stream.writer(io);
try stream.print("hello from server\n", .{});
var buf: [100]u8 = undefined;

View File

@ -1,45 +1,48 @@
//! Uniform Resource Identifier (URI) parsing roughly adhering to <https://tools.ietf.org/html/rfc3986>.
//! Does not do perfect grammar and character class checking, but should be robust against URIs in the wild.
//! Uniform Resource Identifier (URI) parsing roughly adhering to
//! <https://tools.ietf.org/html/rfc3986>. Does not do perfect grammar and
//! character class checking, but should be robust against URIs in the wild.
const std = @import("std.zig");
const testing = std.testing;
const Uri = @This();
const Allocator = std.mem.Allocator;
const Writer = std.Io.Writer;
const HostName = std.Io.net.HostName;
scheme: []const u8,
user: ?Component = null,
password: ?Component = null,
/// If non-null, already validated.
host: ?Component = null,
port: ?u16 = null,
path: Component = Component.empty,
query: ?Component = null,
fragment: ?Component = null,
pub const host_name_max = 255;
pub const GetHostError = error{UriMissingHost};
/// Returned value may point into `buffer` or be the original string.
///
/// Suggested buffer length: `host_name_max`.
///
/// See also:
/// * `getHostAlloc`
pub fn getHost(uri: Uri, buffer: []u8) error{ UriMissingHost, UriHostTooLong }![]const u8 {
pub fn getHost(uri: Uri, buffer: *[HostName.max_len]u8) GetHostError!HostName {
const component = uri.host orelse return error.UriMissingHost;
return component.toRaw(buffer) catch |err| switch (err) {
error.NoSpaceLeft => return error.UriHostTooLong,
const bytes = component.toRaw(buffer) catch |err| switch (err) {
error.NoSpaceLeft => unreachable, // `host` already validated.
};
return .{ .bytes = bytes };
}
pub const GetHostAllocError = GetHostError || error{OutOfMemory};
/// Returned value may point into `buffer` or be the original string.
///
/// See also:
/// * `getHost`
pub fn getHostAlloc(uri: Uri, arena: Allocator) error{ UriMissingHost, UriHostTooLong, OutOfMemory }![]const u8 {
pub fn getHostAlloc(uri: Uri, arena: Allocator) GetHostAllocError![]const u8 {
const component = uri.host orelse return error.UriMissingHost;
const result = try component.toRawMaybeAlloc(arena);
if (result.len > host_name_max) return error.UriHostTooLong;
return result;
const bytes = try component.toRawMaybeAlloc(arena);
return .{ .bytes = bytes };
}
pub const Component = union(enum) {
@ -397,7 +400,7 @@ pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceE
.scheme = new_parsed.scheme,
.user = new_parsed.user,
.password = new_parsed.password,
.host = new_parsed.host,
.host = try validateHost(new_parsed.host),
.port = new_parsed.port,
.path = remove_dot_segments(new_path),
.query = new_parsed.query,
@ -408,7 +411,7 @@ pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceE
.scheme = base.scheme,
.user = new_parsed.user,
.password = new_parsed.password,
.host = host,
.host = try validateHost(host),
.port = new_parsed.port,
.path = remove_dot_segments(new_path),
.query = new_parsed.query,
@ -430,7 +433,7 @@ pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceE
.scheme = base.scheme,
.user = base.user,
.password = base.password,
.host = base.host,
.host = try validateHost(base.host),
.port = base.port,
.path = path,
.query = query,
@ -438,6 +441,11 @@ pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceE
};
}
fn validateHost(bytes: []const u8) []const u8 {
try HostName.validate(bytes);
return bytes;
}
/// In-place implementation of RFC 3986, Section 5.2.4.
fn remove_dot_segments(path: []u8) Component {
var in_i: usize = 0;

View File

@ -2281,7 +2281,7 @@ test "seekTo flushes buffered data" {
}
var read_buffer: [16]u8 = undefined;
var file_reader: std.Io.File.Reader = .init(file, io, &read_buffer);
var file_reader: std.Io.File.Reader = .initAdapted(file, io, &read_buffer);
var buf: [4]u8 = undefined;
try file_reader.interface.readSliceAll(&buf);

View File

@ -15,6 +15,7 @@ const assert = std.debug.assert;
const Io = std.Io;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
const HostName = std.Io.net.HostName;
const Client = @This();
@ -69,7 +70,7 @@ pub const ConnectionPool = struct {
/// The criteria for a connection to be considered a match.
pub const Criteria = struct {
host: []const u8,
host: HostName,
port: u16,
protocol: Protocol,
};
@ -89,7 +90,7 @@ pub const ConnectionPool = struct {
if (connection.port != criteria.port) continue;
// Domain names are case-insensitive (RFC 5890, Section 2.3.2.4)
if (!std.ascii.eqlIgnoreCase(connection.host(), criteria.host)) continue;
if (!connection.host().eql(criteria.host)) continue;
pool.acquireUnsafe(connection);
return connection;
@ -118,19 +119,19 @@ pub const ConnectionPool = struct {
/// If the connection is marked as closing, it will be closed instead.
///
/// Threadsafe.
pub fn release(pool: *ConnectionPool, connection: *Connection) void {
pub fn release(pool: *ConnectionPool, connection: *Connection, io: Io) void {
pool.mutex.lock();
defer pool.mutex.unlock();
pool.used.remove(&connection.pool_node);
if (connection.closing or pool.free_size == 0) return connection.destroy();
if (connection.closing or pool.free_size == 0) return connection.destroy(io);
if (pool.free_len >= pool.free_size) {
const popped: *Connection = @alignCast(@fieldParentPtr("pool_node", pool.free.popFirst().?));
pool.free_len -= 1;
popped.destroy();
popped.destroy(io);
}
if (connection.proxied) {
@ -178,21 +179,21 @@ pub const ConnectionPool = struct {
/// All future operations on the connection pool will deadlock.
///
/// Threadsafe.
pub fn deinit(pool: *ConnectionPool) void {
pub fn deinit(pool: *ConnectionPool, io: Io) void {
pool.mutex.lock();
var next = pool.free.first;
while (next) |node| {
const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
next = node.next;
connection.destroy();
connection.destroy(io);
}
next = pool.used.first;
while (next) |node| {
const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
next = node.next;
connection.destroy();
connection.destroy(io);
}
pool.* = undefined;
@ -242,19 +243,19 @@ pub const Connection = struct {
fn create(
client: *Client,
remote_host: []const u8,
remote_host: HostName,
port: u16,
stream: Io.net.Stream,
) error{OutOfMemory}!*Plain {
const gpa = client.allocator;
const alloc_len = allocLen(client, remote_host.len);
const alloc_len = allocLen(client, remote_host.bytes.len);
const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
errdefer gpa.free(base);
const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
const host_buffer = base[@sizeOf(Plain)..][0..remote_host.bytes.len];
const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
@memcpy(host_buffer, remote_host);
@memcpy(host_buffer, remote_host.bytes);
const plain: *Plain = @ptrCast(base);
plain.* = .{
.connection = .{
@ -263,7 +264,7 @@ pub const Connection = struct {
.stream_reader = stream.reader(socket_read_buffer),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
.host_len = @intCast(remote_host.bytes.len),
.proxied = false,
.closing = false,
.protocol = .plain,
@ -283,9 +284,9 @@ pub const Connection = struct {
return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
}
fn host(plain: *Plain) []u8 {
fn host(plain: *Plain) HostName {
const base: [*]u8 = @ptrCast(plain);
return base[@sizeOf(Plain)..][0..plain.connection.host_len];
return .{ .bytes = base[@sizeOf(Plain)..][0..plain.connection.host_len] };
}
};
@ -295,15 +296,15 @@ pub const Connection = struct {
fn create(
client: *Client,
remote_host: []const u8,
remote_host: HostName,
port: u16,
stream: Io.net.Stream,
) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
const gpa = client.allocator;
const alloc_len = allocLen(client, remote_host.len);
const alloc_len = allocLen(client, remote_host.bytes.len);
const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
errdefer gpa.free(base);
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.bytes.len];
// The TLS client wants enough buffer for the max encrypted frame
// size, and the HTTP body reader wants enough buffer for the
// entire HTTP header. This means we need a combined upper bound.
@ -313,7 +314,7 @@ pub const Connection = struct {
const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size];
assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
@memcpy(host_buffer, remote_host);
@memcpy(host_buffer, remote_host.bytes);
const tls: *Tls = @ptrCast(base);
tls.* = .{
.connection = .{
@ -322,17 +323,17 @@ pub const Connection = struct {
.stream_reader = stream.reader(socket_read_buffer),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
.host_len = @intCast(remote_host.bytes.len),
.proxied = false,
.closing = false,
.protocol = .tls,
},
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
.client = std.crypto.tls.Client.init(
tls.connection.stream_reader.interface(),
&tls.connection.stream_reader.interface,
&tls.connection.stream_writer.interface,
.{
.host = .{ .explicit = remote_host },
.host = .{ .explicit = remote_host.bytes },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_log = client.ssl_key_log,
.read_buffer = tls_read_buffer,
@ -359,9 +360,9 @@ pub const Connection = struct {
client.write_buffer_size + client.tls_buffer_size;
}
fn host(tls: *Tls) []u8 {
fn host(tls: *Tls) HostName {
const base: [*]u8 = @ptrCast(tls);
return base[@sizeOf(Tls)..][0..tls.connection.host_len];
return .{ .bytes = base[@sizeOf(Tls)..][0..tls.connection.host_len] };
}
};
@ -384,7 +385,7 @@ pub const Connection = struct {
return c.stream_reader.stream;
}
pub fn host(c: *Connection) []u8 {
pub fn host(c: *Connection) HostName {
return switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
@ -400,8 +401,8 @@ pub const Connection = struct {
/// If this is called without calling `flush` or `end`, data will be
/// dropped unsent.
pub fn destroy(c: *Connection) void {
c.getStream().close();
pub fn destroy(c: *Connection, io: Io) void {
c.stream_reader.stream.close(io);
switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
@ -437,7 +438,7 @@ pub const Connection = struct {
const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
return &tls.client.reader;
},
.plain => c.stream_reader.interface(),
.plain => &c.stream_reader.interface,
};
}
@ -866,6 +867,7 @@ pub const Request = struct {
/// Returns the request's `Connection` back to the pool of the `Client`.
pub fn deinit(r: *Request) void {
const io = r.client.io;
if (r.connection) |connection| {
connection.closing = connection.closing or switch (r.reader.state) {
.ready => false,
@ -880,7 +882,7 @@ pub const Request = struct {
},
else => true,
};
r.client.connection_pool.release(connection);
r.client.connection_pool.release(connection, io);
}
r.* = undefined;
}
@ -1182,6 +1184,7 @@ pub const Request = struct {
///
/// `aux_buf` must outlive accesses to `Request.uri`.
fn redirect(r: *Request, head: *const Response.Head, aux_buf: *[]u8) !void {
const io = r.client.io;
const new_location = head.location orelse return error.HttpRedirectLocationMissing;
if (new_location.len > aux_buf.*.len) return error.HttpRedirectLocationOversize;
const location = aux_buf.*[0..new_location.len];
@ -1204,13 +1207,13 @@ pub const Request = struct {
const protocol = Protocol.fromUri(new_uri) orelse return error.UnsupportedUriScheme;
const old_connection = r.connection.?;
const old_host = old_connection.host();
var new_host_name_buffer: [Uri.host_name_max]u8 = undefined;
var new_host_name_buffer: [HostName.max_len]u8 = undefined;
const new_host = try new_uri.getHost(&new_host_name_buffer);
const keep_privileged_headers =
std.ascii.eqlIgnoreCase(r.uri.scheme, new_uri.scheme) and
sameParentDomain(old_host, new_host);
old_host.sameParentDomain(new_host);
r.client.connection_pool.release(old_connection);
r.client.connection_pool.release(old_connection, io);
r.connection = null;
if (!keep_privileged_headers) {
@ -1266,7 +1269,7 @@ pub const Request = struct {
pub const Proxy = struct {
protocol: Protocol,
host: []const u8,
host: HostName,
authorization: ?[]const u8,
port: u16,
supports_connect: bool,
@ -1277,9 +1280,10 @@ pub const Proxy = struct {
/// All pending requests must be de-initialized and all active connections released
/// before calling this function.
pub fn deinit(client: *Client) void {
const io = client.io;
assert(client.connection_pool.used.first == null); // There are still active requests.
client.connection_pool.deinit();
client.connection_pool.deinit(io);
if (!disable_tls) client.ca_bundle.deinit(client.allocator);
client.* = undefined;
@ -1385,7 +1389,7 @@ pub const basic_authorization = struct {
}
};
pub const ConnectTcpError = Allocator.Error || error{
pub const ConnectTcpError = error{
ConnectionRefused,
NetworkUnreachable,
ConnectionTimedOut,
@ -1393,17 +1397,16 @@ pub const ConnectTcpError = Allocator.Error || error{
TemporaryNameServerFailure,
NameServerFailure,
UnknownHostName,
HostLacksNetworkAddresses,
UnexpectedConnectFailure,
TlsInitializationFailed,
};
} || Allocator.Error || Io.Cancelable;
/// Reuses a `Connection` if one matching `host` and `port` is already open.
///
/// Threadsafe.
pub fn connectTcp(
client: *Client,
host: []const u8,
host: HostName,
port: u16,
protocol: Protocol,
) ConnectTcpError!*Connection {
@ -1411,16 +1414,17 @@ pub fn connectTcp(
}
pub const ConnectTcpOptions = struct {
host: Io.net.HostName,
host: HostName,
port: u16,
protocol: Protocol,
proxied_host: ?[]const u8 = null,
proxied_host: ?HostName = null,
proxied_port: ?u16 = null,
timeout: Io.Timeout = .none,
};
pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcpError!*Connection {
const host = options.host_name;
const host = options.host;
const port = options.port;
const protocol = options.protocol;
@ -1433,17 +1437,15 @@ pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcp
.protocol = protocol,
})) |conn| return conn;
const stream = host.connectTcp(client.io, port) catch |err| switch (err) {
const stream = host.connect(client.io, port, .{ .mode = .stream }) catch |err| switch (err) {
error.ConnectionRefused => return error.ConnectionRefused,
error.NetworkUnreachable => return error.NetworkUnreachable,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure,
error.NameServerFailure => return error.NameServerFailure,
error.UnknownHostName => return error.UnknownHostName,
error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses,
error.Canceled => return error.Canceled,
else => return error.UnexpectedConnectFailure,
//else => return error.UnexpectedConnectFailure,
};
errdefer stream.close();
@ -1479,7 +1481,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
const stream = try std.net.connectUnixSocket(path);
const stream = try Io.net.connectUnixSocket(path);
errdefer stream.close();
conn.data = .{
@ -1504,9 +1506,10 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti
pub fn connectProxied(
client: *Client,
proxy: *Proxy,
proxied_host: []const u8,
proxied_host: HostName,
proxied_port: u16,
) !*Connection {
const io = client.io;
if (!proxy.supports_connect) return error.TunnelNotSupported;
if (client.connection_pool.findConnection(.{
@ -1526,12 +1529,12 @@ pub fn connectProxied(
});
errdefer {
connection.closing = true;
client.connection_pool.release(connection);
client.connection_pool.release(connection, io);
}
var req = client.request(.CONNECT, .{
.scheme = "http",
.host = .{ .raw = proxied_host },
.host = .{ .raw = proxied_host.bytes },
.port = proxied_port,
}, .{
.redirect_behavior = .unhandled,
@ -1576,7 +1579,7 @@ pub const ConnectError = ConnectTcpError || RequestError;
/// This function is threadsafe.
pub fn connect(
client: *Client,
host: []const u8,
host: HostName,
port: u16,
protocol: Protocol,
) ConnectError!*Connection {
@ -1586,9 +1589,7 @@ pub fn connect(
} orelse return client.connectTcp(host, port, protocol);
// Prevent proxying through itself.
if (std.ascii.eqlIgnoreCase(proxy.host, host) and
proxy.port == port and proxy.protocol == protocol)
{
if (proxy.host.eql(host) and proxy.port == port and proxy.protocol == protocol) {
return client.connectTcp(host, port, protocol);
}
@ -1608,7 +1609,6 @@ pub fn connect(
pub const RequestError = ConnectTcpError || error{
UnsupportedUriScheme,
UriMissingHost,
UriHostTooLong,
CertificateBundleLoadFailure,
};
@ -1697,7 +1697,7 @@ pub fn request(
}
const connection = options.connection orelse c: {
var host_name_buffer: [Uri.host_name_max]u8 = undefined;
var host_name_buffer: [HostName.max_len]u8 = undefined;
const host_name = try uri.getHost(&host_name_buffer);
break :c try client.connect(host_name, uriPort(uri, protocol), protocol);
};
@ -1835,20 +1835,6 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
return .{ .status = response.head.status };
}
pub fn sameParentDomain(parent_host: []const u8, child_host: []const u8) bool {
if (!std.ascii.endsWithIgnoreCase(child_host, parent_host)) return false;
if (child_host.len == parent_host.len) return true;
if (parent_host.len > child_host.len) return false;
return child_host[child_host.len - parent_host.len - 1] == '.';
}
test sameParentDomain {
try testing.expect(!sameParentDomain("foo.com", "bar.com"));
try testing.expect(sameParentDomain("foo.com", "foo.com"));
try testing.expect(sameParentDomain("foo.com", "bar.foo.com"));
try testing.expect(!sameParentDomain("bar.foo.com", "foo.com"));
}
test {
_ = Response;
}

View File

@ -53,7 +53,7 @@ test "trailers" {
const gpa = std.testing.allocator;
var client: http.Client = .{ .allocator = gpa };
var client: http.Client = .{ .allocator = gpa, .io = io };
defer client.deinit();
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/trailer", .{
@ -141,12 +141,13 @@ test "HTTP server handles a chunked transfer coding request" {
"0\r\n" ++
"\r\n";
const gpa = std.testing.allocator;
var stream = try net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
const host_name: net.HostName = try .init("127.0.0.1");
var stream = try host_name.connect(io, test_server.port(), .{ .mode = .stream });
defer stream.close(io);
var stream_writer = stream.writer(&.{});
var stream_writer = stream.writer(io, &.{});
try stream_writer.interface.writeAll(request_bytes);
const gpa = std.testing.allocator;
const expected_response =
"HTTP/1.1 200 OK\r\n" ++
"connection: close\r\n" ++
@ -154,8 +155,8 @@ test "HTTP server handles a chunked transfer coding request" {
"content-type: text/plain\r\n" ++
"\r\n" ++
"message from server!\n";
var stream_reader = stream.reader(&.{});
const response = try stream_reader.interface().allocRemaining(gpa, .limited(expected_response.len + 1));
var stream_reader = stream.reader(io, &.{});
const response = try stream_reader.interface.allocRemaining(gpa, .limited(expected_response.len + 1));
defer gpa.free(response);
try expectEqualStrings(expected_response, response);
}
@ -241,7 +242,7 @@ test "echo content server" {
defer test_server.destroy();
{
var client: http.Client = .{ .allocator = std.testing.allocator };
var client: http.Client = .{ .allocator = std.testing.allocator, .io = io };
defer client.deinit();
try echoTests(&client, test_server.port());
@ -294,14 +295,15 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
defer test_server.destroy();
const request_bytes = "GET /foo HTTP/1.1\r\n\r\n";
const gpa = std.testing.allocator;
var stream = try net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
const host_name: net.HostName = try .init("127.0.0.1");
var stream = try host_name.connect(io, test_server.port(), .{ .mode = .stream });
defer stream.close(io);
var stream_writer = stream.writer(&.{});
var stream_writer = stream.writer(io, &.{});
try stream_writer.interface.writeAll(request_bytes);
var stream_reader = stream.reader(&.{});
const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
var stream_reader = stream.reader(io, &.{});
const gpa = std.testing.allocator;
const response = try stream_reader.interface.allocRemaining(gpa, .unlimited);
defer gpa.free(response);
var expected_response = std.array_list.Managed(u8).init(gpa);
@ -366,14 +368,15 @@ test "receiving arbitrary http headers from the client" {
"CoNneCtIoN:close\r\n" ++
"aoeu: asdf \r\n" ++
"\r\n";
const gpa = std.testing.allocator;
var stream = try net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
const host_name: net.HostName = try .init("127.0.0.1");
var stream = try host_name.connect(io, test_server.port(), .{ .mode = .stream });
defer stream.close(io);
var stream_writer = stream.writer(&.{});
var stream_writer = stream.writer(io, &.{});
try stream_writer.interface.writeAll(request_bytes);
var stream_reader = stream.reader(&.{});
const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
var stream_reader = stream.reader(io, &.{});
const gpa = std.testing.allocator;
const response = try stream_reader.interface.allocRemaining(gpa, .unlimited);
defer gpa.free(response);
var expected_response = std.array_list.Managed(u8).init(gpa);
@ -413,7 +416,7 @@ test "general client/server API coverage" {
else => |e| return e,
};
try handleRequest(&request, net_server.listen_address.getPort());
try handleRequest(&request, net_server.socket.address.getPort());
}
}
}
@ -543,9 +546,9 @@ test "general client/server API coverage" {
fn getUnusedTcpPort() !u16 {
const addr = try net.IpAddress.parse("127.0.0.1", 0);
var s = try addr.listen(.{});
defer s.deinit();
return s.listen_address.in.getPort();
var s = try addr.listen(io, .{});
defer s.deinit(io);
return s.socket.address.getPort();
}
});
defer test_server.destroy();
@ -553,7 +556,7 @@ test "general client/server API coverage" {
const log = std.log.scoped(.client);
const gpa = std.testing.allocator;
var client: http.Client = .{ .allocator = gpa };
var client: http.Client = .{ .allocator = gpa, .io = io };
defer client.deinit();
const port = test_server.port();
@ -918,7 +921,10 @@ test "Server streams both reading and writing" {
});
defer test_server.destroy();
var client: http.Client = .{ .allocator = std.testing.allocator };
var client: http.Client = .{
.allocator = std.testing.allocator,
.io = io,
};
defer client.deinit();
var redirect_buffer: [555]u8 = undefined;
@ -1089,17 +1095,20 @@ fn echoTests(client: *http.Client, port: u16) !void {
}
const TestServer = struct {
io: Io,
shutting_down: bool,
server_thread: std.Thread,
net_server: net.Server,
fn destroy(self: *@This()) void {
const io = self.io;
self.shutting_down = true;
const conn = net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure");
conn.close();
var stream = self.net_server.socket.address.connect(io, .{ .mode = .stream }) catch
@panic("shutdown failure");
stream.close(io);
self.server_thread.join();
self.net_server.deinit();
self.net_server.deinit(io);
std.testing.allocator.destroy(self);
}
@ -1118,6 +1127,7 @@ fn createTestServer(io: Io, S: type) !*TestServer {
const address = try net.IpAddress.parse("127.0.0.1", 0);
const test_server = try std.testing.allocator.create(TestServer);
test_server.* = .{
.io = io,
.net_server = try address.listen(io, .{ .reuse_address = true }),
.shutting_down = false,
.server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}),