From 0bde55e881a14e4a871c216610dd10b386640f2e Mon Sep 17 00:00:00 2001 From: viri Date: Sat, 15 Jan 2022 17:40:45 -0600 Subject: [PATCH] std.Thread(windows): use NT internals for name fns --- lib/std/Thread.zig | 60 +++++++++++++++++++++------------ lib/std/os/windows.zig | 15 --------- lib/std/os/windows/kernel32.zig | 3 -- 3 files changed, 38 insertions(+), 40 deletions(-) diff --git a/lib/std/Thread.zig b/lib/std/Thread.zig index 83c2992a45..05f6fd66d3 100644 --- a/lib/std/Thread.zig +++ b/lib/std/Thread.zig @@ -4,6 +4,7 @@ const std = @import("std.zig"); const builtin = @import("builtin"); +const math = std.math; const os = std.os; const assert = std.debug.assert; const target = builtin.target; @@ -85,20 +86,28 @@ pub fn setName(self: Thread, name: []const u8) SetNameError!void { try file.writer().writeAll(name); return; }, - .windows => if (target.os.isAtLeast(.windows, .win10_rs1)) |res| { - // SetThreadDescription is only available since version 1607, which is 10.0.14393.795 - // See https://en.wikipedia.org/wiki/Microsoft_Windows_SDK - if (!res) return error.Unsupported; + .windows => { + var buf: [max_name_len]u16 = undefined; + const len = try std.unicode.utf8ToUtf16Le(&buf, name); + const byte_len = math.cast(c_ushort, len * 2) catch return error.NameTooLong; - var name_buf_w: [max_name_len:0]u16 = undefined; - const length = try std.unicode.utf8ToUtf16Le(&name_buf_w, name); - name_buf_w[length] = 0; + // Note: NT allocates its own copy, no use-after-free here. + const unicode_string = os.windows.UNICODE_STRING{ + .Length = byte_len, + .MaximumLength = byte_len, + .Buffer = &buf, + }; - try os.windows.SetThreadDescription( + switch (os.windows.ntdll.NtSetInformationThread( self.getHandle(), - @ptrCast(os.windows.LPWSTR, &name_buf_w), - ); - return; + .ThreadNameInformation, + &unicode_string, + @sizeOf(os.windows.UNICODE_STRING), + )) { + .SUCCESS => return, + .NOT_IMPLEMENTED => return error.Unsupported, + else => |err| return os.windows.unexpectedStatus(err), + } }, .macos, .ios, .watchos, .tvos => if (use_pthreads) { // There doesn't seem to be a way to set the name for an arbitrary thread, only the current one. @@ -188,18 +197,25 @@ pub fn getName(self: Thread, buffer_ptr: *[max_name_len:0]u8) GetNameError!?[]co // musl doesn't provide pthread_getname_np and there's no way to retrieve the thread id of an arbitrary thread. return error.Unsupported; }, - .windows => if (target.os.isAtLeast(.windows, .win10_rs1)) |res| { - // GetThreadDescription is only available since version 1607, which is 10.0.14393.795 - // See https://en.wikipedia.org/wiki/Microsoft_Windows_SDK - if (!res) return error.Unsupported; + .windows => { + const buf_capacity = @sizeOf(os.windows.UNICODE_STRING) + (@sizeOf(u16) * max_name_len); + var buf: [buf_capacity]u8 align(@alignOf(os.windows.UNICODE_STRING)) = undefined; - var name_w: os.windows.LPWSTR = undefined; - try os.windows.GetThreadDescription(self.getHandle(), &name_w); - defer os.windows.LocalFree(name_w); - - const data_len = try std.unicode.utf16leToUtf8(buffer, std.mem.sliceTo(name_w, 0)); - - return if (data_len >= 1) buffer[0..data_len] else null; + switch (os.windows.ntdll.NtQueryInformationThread( + self.getHandle(), + .ThreadNameInformation, + &buf, + buf_capacity, + null, + )) { + .SUCCESS => { + const string = @ptrCast(*const os.windows.UNICODE_STRING, &buf); + const len = try std.unicode.utf16leToUtf8(buffer, string.Buffer[0 .. string.Length / 2]); + return if (len > 0) buffer[0..len] else null; + }, + .NOT_IMPLEMENTED => return error.Unsupported, + else => |err| return os.windows.unexpectedStatus(err), + } }, .macos, .ios, .watchos, .tvos => if (use_pthreads) { const err = std.c.pthread_getname_np(self.getHandle(), buffer.ptr, max_name_len + 1); diff --git a/lib/std/os/windows.zig b/lib/std/os/windows.zig index 7642fa80f8..229889fbd0 100644 --- a/lib/std/os/windows.zig +++ b/lib/std/os/windows.zig @@ -2029,21 +2029,6 @@ pub fn unexpectedStatus(status: NTSTATUS) std.os.UnexpectedError { return error.Unexpected; } -pub fn SetThreadDescription(hThread: HANDLE, lpThreadDescription: LPCWSTR) !void { - if (kernel32.SetThreadDescription(hThread, lpThreadDescription) == 0) { - switch (kernel32.GetLastError()) { - else => |err| return unexpectedError(err), - } - } -} -pub fn GetThreadDescription(hThread: HANDLE, ppszThreadDescription: *LPWSTR) !void { - if (kernel32.GetThreadDescription(hThread, ppszThreadDescription) == 0) { - switch (kernel32.GetLastError()) { - else => |err| return unexpectedError(err), - } - } -} - pub const Win32Error = @import("windows/win32error.zig").Win32Error; pub const NTSTATUS = @import("windows/ntstatus.zig").NTSTATUS; pub const LANG = @import("windows/lang.zig"); diff --git a/lib/std/os/windows/kernel32.zig b/lib/std/os/windows/kernel32.zig index b602921648..dfec1b9c6d 100644 --- a/lib/std/os/windows/kernel32.zig +++ b/lib/std/os/windows/kernel32.zig @@ -400,6 +400,3 @@ pub extern "kernel32" fn SleepConditionVariableSRW( pub extern "kernel32" fn TryAcquireSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) BOOLEAN; pub extern "kernel32" fn AcquireSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) void; pub extern "kernel32" fn ReleaseSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) void; - -pub extern "kernel32" fn SetThreadDescription(hThread: HANDLE, lpThreadDescription: LPCWSTR) callconv(WINAPI) HRESULT; -pub extern "kernel32" fn GetThreadDescription(hThread: HANDLE, ppszThreadDescription: *LPWSTR) callconv(WINAPI) HRESULT;