From 15d5988e692c182892a118115fd7025048e06c29 Mon Sep 17 00:00:00 2001 From: Ryan Liptak Date: Sun, 16 Jan 2022 20:11:08 -0800 Subject: [PATCH] Add `process.EnvMap`, a platform-independent environment variable map EnvMap provides the same API as the previously used BufMap (besides `putMove` and `getPtr`), so usage sites of `getEnvMap` can usually remain unchanged. For non-Windows, EnvMap is a wrapper around BufMap. On Windows, it uses a new EnvMapWindows to handle some Windows-specific behavior: - Lookups use Unicode-aware case insensitivity (but `get` cannot return an error because EnvMapWindows has an internal buffer to use for lookup conversions) - Canonical names are returned when iterating the EnvMap Fixes #10561, closes #4603 --- lib/std/buf_map.zig | 2 +- lib/std/os/windows/ntdll.zig | 6 + lib/std/process.zig | 403 +++++++++++++++++++++++++++++++++-- lib/std/unicode.zig | 23 ++ 4 files changed, 421 insertions(+), 13 deletions(-) diff --git a/lib/std/buf_map.zig b/lib/std/buf_map.zig index 5b26ae9684..bd59d6da45 100644 --- a/lib/std/buf_map.zig +++ b/lib/std/buf_map.zig @@ -9,7 +9,7 @@ const testing = std.testing; pub const BufMap = struct { hash_map: BufMapHashMap, - const BufMapHashMap = StringHashMap([]const u8); + pub const BufMapHashMap = StringHashMap([]const u8); /// Create a BufMap backed by a specific allocator. /// That allocator will be used for both backing allocations diff --git a/lib/std/os/windows/ntdll.zig b/lib/std/os/windows/ntdll.zig index e3e590b094..2444d5f487 100644 --- a/lib/std/os/windows/ntdll.zig +++ b/lib/std/os/windows/ntdll.zig @@ -229,6 +229,12 @@ pub extern "ntdll" fn RtlEqualUnicodeString( CaseInSensitive: BOOLEAN, ) callconv(WINAPI) BOOLEAN; +pub extern "NtDll" fn RtlUpcaseUnicodeString( + DestinationString: *UNICODE_STRING, + SourceString: *const UNICODE_STRING, + AllocateDestinationString: BOOLEAN, +) callconv(WINAPI) NTSTATUS; + pub extern "ntdll" fn NtLockFile( FileHandle: HANDLE, Event: ?HANDLE, diff --git a/lib/std/process.zig b/lib/std/process.zig index c0f11b22ce..f5d14cf6da 100644 --- a/lib/std/process.zig +++ b/lib/std/process.zig @@ -2,7 +2,6 @@ const std = @import("std.zig"); const builtin = @import("builtin"); const os = std.os; const fs = std.fs; -const BufMap = std.BufMap; const mem = std.mem; const math = std.math; const Allocator = mem.Allocator; @@ -53,9 +52,385 @@ test "getCwdAlloc" { testing.allocator.free(cwd); } -/// Caller owns resulting `BufMap`. -pub fn getEnvMap(allocator: Allocator) !BufMap { - var result = BufMap.init(allocator); +/// EnvMap for Windows that handles Unicode-aware case insensitivity for lookups, while also +/// providing the canonical environment variable names when iterating. +/// +/// Allows for zero-allocation lookups (even though it needs to do UTF-8 -> UTF-16 -> uppercase +/// conversions) by allocating a buffer large enough to fit the largest environment variable +/// name, and using that when doing lookups (i.e. anything that overflows the buffer can be treated +/// as the environment variable not being found). +pub const EnvMapWindows = struct { + allocator: Allocator, + /// Keys are UTF-16le stored as []const u8 + uppercased_map: std.StringHashMapUnmanaged(EnvValue), + /// Buffer for converting to uppercased UTF-16 on key lookups + /// Must call `reallocUppercaseBuf` before doing any lookups after a `put` call. + uppercase_buf_utf16: []u16 = &[_]u16{}, + max_name_utf16_length: usize = 0, + + pub const EnvValue = struct { + value: []const u8, + canonical_name: []const u8, + }; + + const Self = @This(); + + /// Deinitialize with `deinit`. + pub fn init(allocator: Allocator) Self { + return .{ + .allocator = allocator, + .uppercased_map = std.StringHashMapUnmanaged(EnvValue){}, + }; + } + + pub fn deinit(self: *Self) void { + var it = self.uppercased_map.iterator(); + while (it.next()) |entry| { + self.allocator.free(entry.key_ptr.*); + self.allocator.free(entry.value_ptr.value); + self.allocator.free(entry.value_ptr.canonical_name); + } + self.uppercased_map.deinit(self.allocator); + self.allocator.free(self.uppercase_buf_utf16); + } + + /// Increases the size of the uppercase buffer if the maximum name size has increased. + /// Must be called before any `get` calls after any number of `put` calls. + pub fn reallocUppercaseBuf(self: *Self) !void { + if (self.max_name_utf16_length > self.uppercase_buf_utf16.len) { + self.uppercase_buf_utf16 = try self.allocator.realloc(self.uppercase_buf_utf16, self.max_name_utf16_length); + } + } + + /// Converts `src` to uppercase using `RtlUpcaseUnicodeString` and puts the result in `dest`. + /// Returns the length of the converted UTF-16 string. `dest.len` must be >= `src.len`. + /// + /// Note: As of now, RtlUpcaseUnicodeString does not seem to handle codepoints above 0x10000 + /// (i.e. those that require a surrogate pair), so this function will always return a length + /// equal to `src.len`. However, if RtlUpcaseUnicodeString is updated to handle codepoints above + /// 0x10000, this property would still hold unless there are lowercase <-> uppercase conversions + /// that cross over the boundary between codepoints >= 0x10000 and < 0x10000. + /// TODO: Is it feasible that Unicode lowercase <-> uppercase conversions could cross that boundary? + fn uppercaseName(dest: []u16, src: []const u16) u16 { + assert(dest.len >= src.len); + + const dest_bytes = @intCast(u16, dest.len * 2); + var dest_string = os.windows.UNICODE_STRING{ + .Length = dest_bytes, + .MaximumLength = dest_bytes, + .Buffer = @intToPtr([*]u16, @ptrToInt(dest.ptr)), + }; + const src_bytes = @intCast(u16, src.len * 2); + const src_string = os.windows.UNICODE_STRING{ + .Length = src_bytes, + .MaximumLength = src_bytes, + .Buffer = @intToPtr([*]u16, @ptrToInt(src.ptr)), + }; + const rc = os.windows.ntdll.RtlUpcaseUnicodeString(&dest_string, &src_string, os.windows.FALSE); + switch (rc) { + .SUCCESS => return dest_string.Length / 2, + else => unreachable, // we are not allocating, so no errors should be possible + } + } + + /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and + /// only allocating the uppercase buf afterwards. + pub fn putUtf8(self: *Self, name: []const u8, value: []const u8) !void { + const uppercased_len = len: { + const name_uppercased_utf16 = uppercased: { + var name_utf16_buf = try std.ArrayListAligned(u8, @alignOf(u16)).initCapacity(self.allocator, name.len); + errdefer name_utf16_buf.deinit(); + + var uppercased_len = try std.unicode.utf8ToUtf16LeWriter(name_utf16_buf.writer(), name); + assert(uppercased_len == name_utf16_buf.items.len); + + break :uppercased name_utf16_buf.toOwnedSlice(); + }; + errdefer self.allocator.free(name_uppercased_utf16); + + const name_canonical = try self.allocator.dupe(u8, name); + errdefer self.allocator.free(name_canonical); + + const value_dupe = try self.allocator.dupe(u8, value); + errdefer self.allocator.free(value_dupe); + + const get_or_put = try self.uppercased_map.getOrPut(self.allocator, name_uppercased_utf16); + if (get_or_put.found_existing) { + // note: this is only safe from UAF because the errdefer that frees this value above + // no longer has a possibility of being triggered after this point + self.allocator.free(name_uppercased_utf16); + self.allocator.free(get_or_put.value_ptr.value); + self.allocator.free(get_or_put.value_ptr.canonical_name); + } else { + get_or_put.key_ptr.* = name_uppercased_utf16; + } + get_or_put.value_ptr.value = value_dupe; + get_or_put.value_ptr.canonical_name = name_canonical; + + break :len name_uppercased_utf16.len; + }; + + // The buffer for case conversion for key lookups will need to be as big as the largest + // key stored in the hash map. + self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len); + } + + /// Asserts that the name does not already exist in the map. + /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and + /// only allocating the uppercase buf afterwards. + pub fn putUtf16NoClobber(self: *Self, name_utf16: []const u16, value_utf16: []const u16) !void { + const uppercased_len = len: { + const name_canonical = try std.unicode.utf16leToUtf8Alloc(self.allocator, name_utf16); + errdefer self.allocator.free(name_canonical); + + const value = try std.unicode.utf16leToUtf8Alloc(self.allocator, value_utf16); + errdefer self.allocator.free(value); + + const name_uppercased_utf16 = try self.allocator.alloc(u16, name_utf16.len); + errdefer self.allocator.free(name_uppercased_utf16); + + const uppercased_len = uppercaseName(name_uppercased_utf16, name_utf16); + assert(uppercased_len == name_uppercased_utf16.len); + + try self.uppercased_map.putNoClobber(self.allocator, std.mem.sliceAsBytes(name_uppercased_utf16), EnvValue{ + .value = value, + .canonical_name = name_canonical, + }); + break :len name_uppercased_utf16.len; + }; + + // The buffer for case conversion for key lookups will need to be as big as the largest + // key stored in the hash map. + self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len); + } + + /// Attempts to convert a UTF-8 name into a uppercased UTF-16le name for a lookup. If the + /// name cannot be converted, this function will return `null`. + fn utf8ToUppercasedUtf16(self: Self, name: []const u8) ?[]u16 { + const name_utf16: []u16 = to_utf16: { + var utf16_buf_stream = std.io.fixedBufferStream(std.mem.sliceAsBytes(self.uppercase_buf_utf16)); + _ = std.unicode.utf8ToUtf16LeWriter(utf16_buf_stream.writer(), name) catch |err| switch (err) { + // If the buffer isn't large enough, we can treat that as 'env var not found', as we + // know anything too large for the buffer can't be found in the map. + error.NoSpaceLeft => return null, + // Anything with invalid UTF-8 will also not be found in the map, so treat that as + // 'env var not found' too + error.InvalidUtf8 => return null, + }; + break :to_utf16 std.mem.bytesAsSlice(u16, utf16_buf_stream.getWritten()); + }; + + // uppercase in place + const uppercased_len = uppercaseName(name_utf16, name_utf16); + assert(uppercased_len == name_utf16.len); + + return name_utf16; + } + + /// Returns true if an entry was found and deleted, false otherwise. + pub fn remove(self: *Self, name: []const u8) bool { + const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return false; + const kv = self.uppercased_map.fetchRemove(std.mem.sliceAsBytes(name_utf16)) orelse return false; + self.allocator.free(kv.key); + self.allocator.free(kv.value.value); + self.allocator.free(kv.value.canonical_name); + return true; + } + + pub fn get(self: Self, name: []const u8) ?EnvValue { + const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return null; + return self.uppercased_map.get(std.mem.sliceAsBytes(name_utf16)); + } + + pub fn count(self: Self) EnvMap.Size { + return self.uppercased_map.count(); + } + + pub fn iterator(self: *const Self) Iterator { + return .{ + .env_map = self, + .uppercased_map_iterator = self.uppercased_map.iterator(), + }; + } + + pub const Iterator = struct { + env_map: *const Self, + uppercased_map_iterator: std.StringHashMapUnmanaged(EnvValue).Iterator, + + pub fn next(it: *Iterator) ?EnvMap.Entry { + if (it.uppercased_map_iterator.next()) |uppercased_entry| { + return EnvMap.Entry{ + .name = uppercased_entry.value_ptr.canonical_name, + .value = uppercased_entry.value_ptr.value, + }; + } else { + return null; + } + } + }; +}; + +test "EnvMapWindows" { + if (builtin.os.tag != .windows) return error.SkipZigTest; + + var env_map = EnvMapWindows.init(testing.allocator); + defer env_map.deinit(); + + // both put methods + try env_map.putUtf16NoClobber(std.unicode.utf8ToUtf16LeStringLiteral("Path"), std.unicode.utf8ToUtf16LeStringLiteral("something")); + try env_map.putUtf8("КИРИЛЛИЦА", "something else"); + try env_map.reallocUppercaseBuf(); + + try testing.expectEqual(@as(EnvMap.Size, 2), env_map.count()); + + // unicode-aware case-insensitive lookups + try testing.expectEqualStrings("something", env_map.get("PATH").?.value); + try testing.expectEqualStrings("something else", env_map.get("кириллица").?.value); + try testing.expect(env_map.get("missing") == null); + + // canonical names when iterating + var it = env_map.iterator(); + var count: EnvMap.Size = 0; + while (it.next()) |entry| { + const is_an_expected_name = std.mem.eql(u8, "Path", entry.name) or std.mem.eql(u8, "КИРИЛЛИЦА", entry.name); + try testing.expect(is_an_expected_name); + count += 1; + } + try testing.expectEqual(@as(EnvMap.Size, 2), count); +} + +pub const EnvMap = struct { + storage: StorageType, + + pub const StorageType = switch (builtin.os.tag) { + .windows => EnvMapWindows, + else => std.BufMap, + }; + + /// Matches what BufMap uses for its internal HashMap Size + pub const Size = u32; + + const Self = @This(); + + /// Deinitialize with `deinit`. + pub fn init(allocator: Allocator) Self { + return Self{ .storage = StorageType.init(allocator) }; + } + + pub fn deinit(self: *Self) void { + self.storage.deinit(); + } + + pub fn get(self: Self, name: []const u8) ?[]const u8 { + switch (builtin.os.tag) { + .windows => { + if (self.storage.get(name)) |entry| { + return entry.value; + } else { + return null; + } + }, + else => return self.storage.get(name), + } + } + + pub fn count(self: Self) Size { + return self.storage.count(); + } + + pub fn iterator(self: *const Self) Iterator { + return .{ .storage_iterator = self.storage.iterator() }; + } + + pub fn put(self: *Self, name: []const u8, value: []const u8) !void { + switch (builtin.os.tag) { + .windows => { + try self.storage.putUtf8(name, value); + try self.storage.reallocUppercaseBuf(); + }, + else => return self.storage.put(name, value), + } + } + + pub fn remove(self: *Self, name: []const u8) void { + _ = self.storage.remove(name); + } + + pub const Entry = struct { + name: []const u8, + value: []const u8, + }; + + pub const Iterator = struct { + storage_iterator: switch (builtin.os.tag) { + .windows => EnvMapWindows.Iterator, + else => std.BufMap.BufMapHashMap.Iterator, + }, + + pub fn next(it: *Iterator) ?Entry { + switch (builtin.os.tag) { + .windows => return it.storage_iterator.next(), + else => { + if (it.storage_iterator.next()) |entry| { + return Entry{ + .name = entry.key_ptr.*, + .value = entry.value_ptr.*, + }; + } else { + return null; + } + }, + } + } + }; +}; + +test "EnvMap" { + var env = EnvMap.init(testing.allocator); + defer env.deinit(); + + try env.put("SOMETHING_NEW", "hello"); + try testing.expectEqualStrings("hello", env.get("SOMETHING_NEW").?); + try testing.expectEqual(@as(EnvMap.Size, 1), env.count()); + + // overwrite + try env.put("SOMETHING_NEW", "something"); + try testing.expectEqualStrings("something", env.get("SOMETHING_NEW").?); + try testing.expectEqual(@as(EnvMap.Size, 1), env.count()); + + // a new longer name to test the Windows-specific conversion buffer + try env.put("SOMETHING_NEW_AND_LONGER", "1"); + try testing.expectEqualStrings("1", env.get("SOMETHING_NEW_AND_LONGER").?); + try testing.expectEqual(@as(EnvMap.Size, 2), env.count()); + + // case insensitivity on Windows only + if (builtin.os.tag == .windows) { + try testing.expectEqualStrings("1", env.get("something_New_aNd_LONGER").?); + } else { + try testing.expect(null == env.get("something_New_aNd_LONGER")); + } + + var it = env.iterator(); + var count: EnvMap.Size = 0; + while (it.next()) |entry| { + const is_an_expected_name = std.mem.eql(u8, "SOMETHING_NEW", entry.name) or std.mem.eql(u8, "SOMETHING_NEW_AND_LONGER", entry.name); + try testing.expect(is_an_expected_name); + count += 1; + } + try testing.expectEqual(@as(EnvMap.Size, 2), count); + + env.remove("SOMETHING_NEW"); + try testing.expect(env.get("SOMETHING_NEW") == null); + + try testing.expectEqual(@as(EnvMap.Size, 1), env.count()); +} + +/// Returns a snapshot of the environment variables of the current process. +/// Any modifications to the resulting EnvMap will not be not reflected in the environment, and +/// likewise, any future modifications to the environment will not be reflected in the EnvMap. +/// Caller owns resulting `EnvMap` and should call its `deinit` fn when done. +pub fn getEnvMap(allocator: Allocator) !EnvMap { + var result = EnvMap.init(allocator); errdefer result.deinit(); if (builtin.os.tag == .windows) { @@ -65,23 +440,27 @@ pub fn getEnvMap(allocator: Allocator) !BufMap { while (ptr[i] != 0) { const key_start = i; + // There are some special environment variables that start with =, + // so we need a special case to not treat = as a key/value separator + // if it's the first character. + // https://devblogs.microsoft.com/oldnewthing/20100506-00/?p=14133 + if (ptr[key_start] == '=') i += 1; + while (ptr[i] != 0 and ptr[i] != '=') : (i += 1) {} const key_w = ptr[key_start..i]; - const key = try std.unicode.utf16leToUtf8Alloc(allocator, key_w); - errdefer allocator.free(key); if (ptr[i] == '=') i += 1; const value_start = i; while (ptr[i] != 0) : (i += 1) {} const value_w = ptr[value_start..i]; - const value = try std.unicode.utf16leToUtf8Alloc(allocator, value_w); - errdefer allocator.free(value); + + try result.storage.putUtf16NoClobber(key_w, value_w); i += 1; // skip over null byte - - try result.putMove(key, value); } + + try result.storage.reallocUppercaseBuf(); return result; } else if (builtin.os.tag == .wasi and !builtin.link_libc) { var environ_count: usize = undefined; @@ -140,8 +519,8 @@ pub fn getEnvMap(allocator: Allocator) !BufMap { } } -test "os.getEnvMap" { - var env = try getEnvMap(std.testing.allocator); +test "getEnvMap" { + var env = try getEnvMap(testing.allocator); defer env.deinit(); } diff --git a/lib/std/unicode.zig b/lib/std/unicode.zig index 81a7ed838f..706b12105a 100644 --- a/lib/std/unicode.zig +++ b/lib/std/unicode.zig @@ -710,6 +710,29 @@ pub fn utf8ToUtf16Le(utf16le: []u16, utf8: []const u8) !usize { return dest_i; } +pub fn utf8ToUtf16LeWriter(writer: anytype, utf8: []const u8) !usize { + var src_i: usize = 0; + var bytes_written: usize = 0; + while (src_i < utf8.len) { + const n = utf8ByteSequenceLength(utf8[src_i]) catch return error.InvalidUtf8; + const next_src_i = src_i + n; + const codepoint = utf8Decode(utf8[src_i..next_src_i]) catch return error.InvalidUtf8; + if (codepoint < 0x10000) { + const short = @intCast(u16, codepoint); + try writer.writeIntLittle(u16, short); + bytes_written += 2; + } else { + const high = @intCast(u16, (codepoint - 0x10000) >> 10) + 0xD800; + const low = @intCast(u16, codepoint & 0x3FF) + 0xDC00; + try writer.writeIntLittle(u16, high); + try writer.writeIntLittle(u16, low); + bytes_written += 4; + } + src_i = next_src_i; + } + return bytes_written; +} + test "utf8ToUtf16Le" { var utf16le: [2]u16 = [_]u16{0} ** 2; {