diff --git a/lib/std/io.zig b/lib/std/io.zig index 27c6337626..75693b8b1e 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -512,9 +512,17 @@ pub fn Poller(comptime StreamEnum: type) type { pub fn poll(self: *Self) !bool { if (builtin.os.tag == .windows) { - return pollWindows(self); + return pollWindows(self, null); } else { - return pollPosix(self); + return pollPosix(self, null); + } + } + + pub fn pollTimeout(self: *Self, nanoseconds: u64) !bool { + if (builtin.os.tag == .windows) { + return pollWindows(self, nanoseconds); + } else { + return pollPosix(self, nanoseconds); } } @@ -522,7 +530,7 @@ pub fn Poller(comptime StreamEnum: type) type { return &self.fifos[@intFromEnum(which)]; } - fn pollWindows(self: *Self) !bool { + fn pollWindows(self: *Self, nanoseconds: ?u64) !bool { const bump_amt = 512; if (!self.windows.first_read_done) { @@ -553,10 +561,15 @@ pub fn Poller(comptime StreamEnum: type) type { self.windows.active.count, &self.windows.active.handles_buf, 0, - os.windows.INFINITE, + if (nanoseconds) |ns| + @min(std.math.cast(u32, ns / std.time.ns_per_ms) orelse (os.windows.INFINITE - 1), os.windows.INFINITE - 1) + else + os.windows.INFINITE, ); if (status == os.windows.WAIT_FAILED) return os.windows.unexpectedError(os.windows.kernel32.GetLastError()); + if (status == os.windows.WAIT_TIMEOUT) + return true; if (status < os.windows.WAIT_OBJECT_0 or status > os.windows.WAIT_OBJECT_0 + enum_fields.len - 1) unreachable; @@ -594,7 +607,7 @@ pub fn Poller(comptime StreamEnum: type) type { } } - fn pollPosix(self: *Self) !bool { + fn pollPosix(self: *Self, nanoseconds: ?u64) !bool { // We ask for ensureUnusedCapacity with this much extra space. This // has more of an effect on small reads because once the reads // start to get larger the amount of space an ArrayList will @@ -603,7 +616,10 @@ pub fn Poller(comptime StreamEnum: type) type { const err_mask = os.POLL.ERR | os.POLL.NVAL | os.POLL.HUP; - const events_len = try os.poll(&self.poll_fds, std.math.maxInt(i32)); + const events_len = try os.poll(&self.poll_fds, if (nanoseconds) |ns| + std.math.cast(i32, ns / std.time.ns_per_ms) orelse std.math.maxInt(i32) + else + -1); if (events_len == 0) { for (self.poll_fds) |poll_fd| { if (poll_fd.fd != -1) return true;