From 600b652825ef4e1360f6e9d6ad5a20c32fada5ca Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 24 Apr 2024 13:49:43 -0700 Subject: [PATCH] Merge pull request #19698 from squeek502/windows-batbadbut std.process.Child: Mitigate arbitrary command execution vulnerability on Windows (BatBadBut) --- lib/std/child_process.zig | 288 +++++++++++++++++- lib/std/os/windows/kernel32.zig | 2 + lib/std/unicode.zig | 70 ++++- test/standalone/build.zig.zon | 3 + test/standalone/windows_bat_args/build.zig | 58 ++++ .../standalone/windows_bat_args/echo-args.zig | 14 + test/standalone/windows_bat_args/fuzz.zig | 160 ++++++++++ test/standalone/windows_bat_args/test.zig | 132 ++++++++ 8 files changed, 709 insertions(+), 18 deletions(-) create mode 100644 test/standalone/windows_bat_args/build.zig create mode 100644 test/standalone/windows_bat_args/echo-args.zig create mode 100644 test/standalone/windows_bat_args/fuzz.zig create mode 100644 test/standalone/windows_bat_args/test.zig diff --git a/lib/std/child_process.zig b/lib/std/child_process.zig index dcc00b77d5..e6069aec8e 100644 --- a/lib/std/child_process.zig +++ b/lib/std/child_process.zig @@ -136,6 +136,14 @@ pub const ChildProcess = struct { /// Windows-only. `cwd` was provided, but the path did not exist when spawning the child process. CurrentWorkingDirectoryUnlinked, + + /// Windows-only. NUL (U+0000), LF (U+000A), CR (U+000D) are not allowed + /// within arguments when executing a `.bat`/`.cmd` script. + /// - NUL/LF signifiies end of arguments, so anything afterwards + /// would be lost after execution. + /// - CR is stripped by `cmd.exe`, so any CR codepoints + /// would be lost after execution. + InvalidBatchScriptArg, } || posix.ExecveError || posix.SetIdError || @@ -814,17 +822,20 @@ pub const ChildProcess = struct { const app_name_w = try unicode.wtf8ToWtf16LeAllocZ(self.allocator, app_basename_wtf8); defer self.allocator.free(app_name_w); - const cmd_line_w = argvToCommandLineWindows(self.allocator, self.argv) catch |err| switch (err) { - // argv[0] contains unsupported characters that will never resolve to a valid exe. - error.InvalidArg0 => return error.FileNotFound, - else => |e| return e, - }; - defer self.allocator.free(cmd_line_w); - run: { const PATH: [:0]const u16 = std.process.getenvW(unicode.utf8ToUtf16LeStringLiteral("PATH")) orelse &[_:0]u16{}; const PATHEXT: [:0]const u16 = std.process.getenvW(unicode.utf8ToUtf16LeStringLiteral("PATHEXT")) orelse &[_:0]u16{}; + // In case the command ends up being a .bat/.cmd script, we need to escape things using the cmd.exe rules + // and invoke cmd.exe ourselves in order to mitigate arbitrary command execution from maliciously + // constructed arguments. + // + // We'll need to wait until we're actually trying to run the command to know for sure + // if the resolved command has the `.bat` or `.cmd` extension, so we defer actually + // serializing the command line until we determine how it should be serialized. + var cmd_line_cache = WindowsCommandLineCache.init(self.allocator, self.argv); + defer cmd_line_cache.deinit(); + var app_buf = std.ArrayListUnmanaged(u16){}; defer app_buf.deinit(self.allocator); @@ -846,8 +857,10 @@ pub const ChildProcess = struct { dir_buf.shrinkRetainingCapacity(normalized_len); } - windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, cmd_line_w.ptr, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo) catch |no_path_err| { + windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo) catch |no_path_err| { const original_err = switch (no_path_err) { + // argv[0] contains unsupported characters that will never resolve to a valid exe. + error.InvalidArg0 => return error.FileNotFound, error.FileNotFound, error.InvalidExe, error.AccessDenied => |e| e, error.UnrecoverableInvalidExe => return error.InvalidExe, else => |e| return e, @@ -872,9 +885,11 @@ pub const ChildProcess = struct { const normalized_len = windows.normalizePath(u16, dir_buf.items) catch continue; dir_buf.shrinkRetainingCapacity(normalized_len); - if (windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, cmd_line_w.ptr, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo)) { + if (windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo)) { break :run; } else |err| switch (err) { + // argv[0] contains unsupported characters that will never resolve to a valid exe. + error.InvalidArg0 => return error.FileNotFound, error.FileNotFound, error.AccessDenied, error.InvalidExe => continue, error.UnrecoverableInvalidExe => return error.InvalidExe, else => |e| return e, @@ -935,7 +950,7 @@ fn windowsCreateProcessPathExt( dir_buf: *std.ArrayListUnmanaged(u16), app_buf: *std.ArrayListUnmanaged(u16), pathext: [:0]const u16, - cmd_line: [*:0]u16, + cmd_line_cache: *WindowsCommandLineCache, envp_ptr: ?[*]u16, cwd_ptr: ?[*:0]u16, lpStartupInfo: *windows.STARTUPINFOW, @@ -1069,7 +1084,26 @@ fn windowsCreateProcessPathExt( try dir_buf.append(allocator, 0); const full_app_name = dir_buf.items[0 .. dir_buf.items.len - 1 :0]; - if (windowsCreateProcess(full_app_name.ptr, cmd_line, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| { + const is_bat_or_cmd = bat_or_cmd: { + const app_name = app_buf.items[0..app_name_len]; + const ext_start = std.mem.lastIndexOfScalar(u16, app_name, '.') orelse break :bat_or_cmd false; + const ext = app_name[ext_start..]; + const ext_enum = windowsCreateProcessSupportsExtension(ext) orelse break :bat_or_cmd false; + switch (ext_enum) { + .cmd, .bat => break :bat_or_cmd true, + else => break :bat_or_cmd false, + } + }; + const cmd_line_w = if (is_bat_or_cmd) + try cmd_line_cache.scriptCommandLine(full_app_name) + else + try cmd_line_cache.commandLine(); + const app_name_w = if (is_bat_or_cmd) + try cmd_line_cache.cmdExePath() + else + full_app_name; + + if (windowsCreateProcess(app_name_w.ptr, cmd_line_w.ptr, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| { return; } else |err| switch (err) { error.FileNotFound, @@ -1111,7 +1145,20 @@ fn windowsCreateProcessPathExt( try dir_buf.append(allocator, 0); const full_app_name = dir_buf.items[0 .. dir_buf.items.len - 1 :0]; - if (windowsCreateProcess(full_app_name.ptr, cmd_line, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| { + const is_bat_or_cmd = switch (ext_enum) { + .cmd, .bat => true, + else => false, + }; + const cmd_line_w = if (is_bat_or_cmd) + try cmd_line_cache.scriptCommandLine(full_app_name) + else + try cmd_line_cache.commandLine(); + const app_name_w = if (is_bat_or_cmd) + try cmd_line_cache.cmdExePath() + else + full_app_name; + + if (windowsCreateProcess(app_name_w.ptr, cmd_line_w.ptr, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| { return; } else |err| switch (err) { error.FileNotFound => continue, @@ -1236,6 +1283,223 @@ test windowsCreateProcessSupportsExtension { try std.testing.expect(windowsCreateProcessSupportsExtension(&[_]u16{ '.', 'e', 'X', 'e', 'c' }) == null); } +/// Serializes argv into a WTF-16 encoded command-line string for use with CreateProcessW. +/// +/// Serialization is done on-demand and the result is cached in order to allow for: +/// - Only serializing the particular type of command line needed (`.bat`/`.cmd` +/// command line serialization is different from `.exe`/etc) +/// - Reusing the serialized command lines if necessary (i.e. if the execution +/// of a command fails and the PATH is going to be continued to be searched +/// for more candidates) +pub const WindowsCommandLineCache = struct { + cmd_line: ?[:0]u16 = null, + script_cmd_line: ?[:0]u16 = null, + cmd_exe_path: ?[:0]u16 = null, + argv: []const []const u8, + allocator: mem.Allocator, + + pub fn init(allocator: mem.Allocator, argv: []const []const u8) WindowsCommandLineCache { + return .{ + .allocator = allocator, + .argv = argv, + }; + } + + pub fn deinit(self: *WindowsCommandLineCache) void { + if (self.cmd_line) |cmd_line| self.allocator.free(cmd_line); + if (self.script_cmd_line) |script_cmd_line| self.allocator.free(script_cmd_line); + if (self.cmd_exe_path) |cmd_exe_path| self.allocator.free(cmd_exe_path); + } + + pub fn commandLine(self: *WindowsCommandLineCache) ![:0]u16 { + if (self.cmd_line == null) { + self.cmd_line = try argvToCommandLineWindows(self.allocator, self.argv); + } + return self.cmd_line.?; + } + + /// Not cached, since the path to the batch script will change during PATH searching. + /// `script_path` should be as qualified as possible, e.g. if the PATH is being searched, + /// then script_path should include both the search path and the script filename + /// (this allows avoiding cmd.exe having to search the PATH again). + pub fn scriptCommandLine(self: *WindowsCommandLineCache, script_path: []const u16) ![:0]u16 { + if (self.script_cmd_line) |v| self.allocator.free(v); + self.script_cmd_line = try argvToScriptCommandLineWindows( + self.allocator, + script_path, + self.argv[1..], + ); + return self.script_cmd_line.?; + } + + pub fn cmdExePath(self: *WindowsCommandLineCache) ![:0]u16 { + if (self.cmd_exe_path == null) { + self.cmd_exe_path = try windowsCmdExePath(self.allocator); + } + return self.cmd_exe_path.?; + } +}; + +pub fn windowsCmdExePath(allocator: mem.Allocator) error{ OutOfMemory, Unexpected }![:0]u16 { + var buf = try std.ArrayListUnmanaged(u16).initCapacity(allocator, 128); + errdefer buf.deinit(allocator); + while (true) { + const unused_slice = buf.unusedCapacitySlice(); + // TODO: Get the system directory from PEB.ReadOnlyStaticServerData + const len = windows.kernel32.GetSystemDirectoryW(@ptrCast(unused_slice), @intCast(unused_slice.len)); + if (len == 0) { + switch (windows.kernel32.GetLastError()) { + else => |err| return windows.unexpectedError(err), + } + } + if (len > unused_slice.len) { + try buf.ensureUnusedCapacity(allocator, len); + } else { + buf.items.len = len; + break; + } + } + switch (buf.items[buf.items.len - 1]) { + '/', '\\' => {}, + else => try buf.append(allocator, fs.path.sep), + } + try buf.appendSlice(allocator, std.unicode.utf8ToUtf16LeStringLiteral("cmd.exe")); + return try buf.toOwnedSliceSentinel(allocator, 0); +} + +pub const ArgvToScriptCommandLineError = error{ + OutOfMemory, + InvalidWtf8, + /// NUL (U+0000), LF (U+000A), CR (U+000D) are not allowed + /// within arguments when executing a `.bat`/`.cmd` script. + /// - NUL/LF signifiies end of arguments, so anything afterwards + /// would be lost after execution. + /// - CR is stripped by `cmd.exe`, so any CR codepoints + /// would be lost after execution. + InvalidBatchScriptArg, +}; + +/// Serializes `argv` to a Windows command-line string that uses `cmd.exe /c` and `cmd.exe`-specific +/// escaping rules. The caller owns the returned slice. +/// +/// Escapes `argv` using the suggested mitigation against arbitrary command execution from: +/// https://flatt.tech/research/posts/batbadbut-you-cant-securely-execute-commands-on-windows/ +pub fn argvToScriptCommandLineWindows( + allocator: mem.Allocator, + /// Path to the `.bat`/`.cmd` script. If this path is relative, it is assumed to be relative to the CWD. + /// The script must have been verified to exist at this path before calling this function. + script_path: []const u16, + /// Arguments, not including the script name itself. Expected to be encoded as WTF-8. + script_args: []const []const u8, +) ArgvToScriptCommandLineError![:0]u16 { + var buf = try std.ArrayList(u8).initCapacity(allocator, 64); + defer buf.deinit(); + + // `/d` disables execution of AutoRun commands. + // `/e:ON` and `/v:OFF` are needed for BatBadBut mitigation: + // > If delayed expansion is enabled via the registry value DelayedExpansion, + // > it must be disabled by explicitly calling cmd.exe with the /V:OFF option. + // > Escaping for % requires the command extension to be enabled. + // > If it’s disabled via the registry value EnableExtensions, it must be enabled with the /E:ON option. + // https://flatt.tech/research/posts/batbadbut-you-cant-securely-execute-commands-on-windows/ + buf.appendSliceAssumeCapacity("cmd.exe /d /e:ON /v:OFF /c \""); + + // Always quote the path to the script arg + buf.appendAssumeCapacity('"'); + // We always want the path to the batch script to include a path separator in order to + // avoid cmd.exe searching the PATH for the script. This is not part of the arbitrary + // command execution mitigation, we just know exactly what script we want to execute + // at this point, and potentially making cmd.exe re-find it is unnecessary. + // + // If the script path does not have a path separator, then we know its relative to CWD and + // we can just put `.\` in the front. + if (mem.indexOfAny(u16, script_path, &[_]u16{ mem.nativeToLittle(u16, '\\'), mem.nativeToLittle(u16, '/') }) == null) { + try buf.appendSlice(".\\"); + } + // Note that we don't do any escaping/mitigations for this argument, since the relevant + // characters (", %, etc) are illegal in file paths and this function should only be called + // with script paths that have been verified to exist. + try std.unicode.wtf16LeToWtf8ArrayList(&buf, script_path); + buf.appendAssumeCapacity('"'); + + for (script_args) |arg| { + // Literal carriage returns get stripped when run through cmd.exe + // and NUL/newlines act as 'end of command.' Because of this, it's basically + // always a mistake to include these characters in argv, so it's + // an error condition in order to ensure that the return of this + // function can always roundtrip through cmd.exe. + if (std.mem.indexOfAny(u8, arg, "\x00\r\n") != null) { + return error.InvalidBatchScriptArg; + } + + // Separate args with a space. + try buf.append(' '); + + // Need to quote if the argument is empty (otherwise the arg would just be lost) + // or if the last character is a `\`, since then something like "%~2" in a .bat + // script would cause the closing " to be escaped which we don't want. + var needs_quotes = arg.len == 0 or arg[arg.len - 1] == '\\'; + if (!needs_quotes) { + for (arg) |c| { + switch (c) { + // Known good characters that don't need to be quoted + 'A'...'Z', 'a'...'z', '0'...'9', '#', '$', '*', '+', '-', '.', '/', ':', '?', '@', '\\', '_' => {}, + // When in doubt, quote + else => { + needs_quotes = true; + break; + }, + } + } + } + if (needs_quotes) { + try buf.append('"'); + } + var backslashes: usize = 0; + for (arg) |c| { + switch (c) { + '\\' => { + backslashes += 1; + }, + '"' => { + try buf.appendNTimes('\\', backslashes); + try buf.append('"'); + backslashes = 0; + }, + // Replace `%` with `%%cd:~,%`. + // + // cmd.exe allows extracting a substring from an environment + // variable with the syntax: `%foo:~,%`. + // Therefore, `%cd:~,%` will always expand to an empty string + // since both the start and end index are blank, and it is assumed + // that `%cd%` is always available since it is a built-in variable + // that corresponds to the current directory. + // + // This means that replacing `%foo%` with `%%cd:~,%foo%%cd:~,%` + // will stop `%foo%` from being expanded and *after* expansion + // we'll still be left with `%foo%` (the literal string). + '%' => { + // the trailing `%` is appended outside the switch + try buf.appendSlice("%%cd:~,"); + backslashes = 0; + }, + else => { + backslashes = 0; + }, + } + try buf.append(c); + } + if (needs_quotes) { + try buf.appendNTimes('\\', backslashes); + try buf.append('"'); + } + } + + try buf.append('"'); + + return try unicode.wtf8ToWtf16LeAllocZ(allocator, buf.items); +} + pub const ArgvToCommandLineError = error{ OutOfMemory, InvalidWtf8, InvalidArg0 }; /// Serializes `argv` to a Windows command-line string suitable for passing to a child process and diff --git a/lib/std/os/windows/kernel32.zig b/lib/std/os/windows/kernel32.zig index a1ca655ed1..668a624292 100644 --- a/lib/std/os/windows/kernel32.zig +++ b/lib/std/os/windows/kernel32.zig @@ -243,6 +243,8 @@ pub extern "kernel32" fn GetSystemInfo(lpSystemInfo: *SYSTEM_INFO) callconv(WINA pub extern "kernel32" fn GetSystemTimeAsFileTime(*FILETIME) callconv(WINAPI) void; pub extern "kernel32" fn IsProcessorFeaturePresent(ProcessorFeature: DWORD) BOOL; +pub extern "kernel32" fn GetSystemDirectoryW(lpBuffer: LPWSTR, uSize: UINT) callconv(WINAPI) UINT; + pub extern "kernel32" fn HeapCreate(flOptions: DWORD, dwInitialSize: SIZE_T, dwMaximumSize: SIZE_T) callconv(WINAPI) ?HANDLE; pub extern "kernel32" fn HeapDestroy(hHeap: HANDLE) callconv(WINAPI) BOOL; pub extern "kernel32" fn HeapReAlloc(hHeap: HANDLE, dwFlags: DWORD, lpMem: *anyopaque, dwBytes: SIZE_T) callconv(WINAPI) ?*anyopaque; diff --git a/lib/std/unicode.zig b/lib/std/unicode.zig index 9d1f52fb2d..327c485fd7 100644 --- a/lib/std/unicode.zig +++ b/lib/std/unicode.zig @@ -934,7 +934,7 @@ fn utf16LeToUtf8ArrayListImpl( .cannot_encode_surrogate_half => Utf16LeToUtf8AllocError, .can_encode_surrogate_half => mem.Allocator.Error, })!void { - assert(result.capacity >= utf16le.len); + assert(result.unusedCapacitySlice().len >= utf16le.len); var remaining = utf16le; vectorized: { @@ -979,7 +979,7 @@ fn utf16LeToUtf8ArrayListImpl( pub const Utf16LeToUtf8AllocError = mem.Allocator.Error || Utf16LeToUtf8Error; pub fn utf16LeToUtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) Utf16LeToUtf8AllocError!void { - try result.ensureTotalCapacityPrecise(utf16le.len); + try result.ensureUnusedCapacity(utf16le.len); return utf16LeToUtf8ArrayListImpl(result, utf16le, .cannot_encode_surrogate_half); } @@ -1138,7 +1138,7 @@ test utf16LeToUtf8 { } fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, comptime surrogates: Surrogates) !void { - assert(result.capacity >= utf8.len); + assert(result.unusedCapacitySlice().len >= utf8.len); var remaining = utf8; vectorized: { @@ -1176,7 +1176,7 @@ fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, com } pub fn utf8ToUtf16LeArrayList(result: *std.ArrayList(u16), utf8: []const u8) error{ InvalidUtf8, OutOfMemory }!void { - try result.ensureTotalCapacityPrecise(utf8.len); + try result.ensureUnusedCapacity(utf8.len); return utf8ToUtf16LeArrayListImpl(result, utf8, .cannot_encode_surrogate_half); } @@ -1351,6 +1351,64 @@ test utf8ToUtf16LeAllocZ { } } +test "ArrayList functions on a re-used list" { + // utf8ToUtf16LeArrayList + { + var list = std.ArrayList(u16).init(testing.allocator); + defer list.deinit(); + + const init_slice = utf8ToUtf16LeStringLiteral("abcdefg"); + try list.ensureTotalCapacityPrecise(init_slice.len); + list.appendSliceAssumeCapacity(init_slice); + + try utf8ToUtf16LeArrayList(&list, "hijklmnopqrstuvwyxz"); + + try testing.expectEqualSlices(u16, utf8ToUtf16LeStringLiteral("abcdefghijklmnopqrstuvwyxz"), list.items); + } + + // utf16LeToUtf8ArrayList + { + var list = std.ArrayList(u8).init(testing.allocator); + defer list.deinit(); + + const init_slice = "abcdefg"; + try list.ensureTotalCapacityPrecise(init_slice.len); + list.appendSliceAssumeCapacity(init_slice); + + try utf16LeToUtf8ArrayList(&list, utf8ToUtf16LeStringLiteral("hijklmnopqrstuvwyxz")); + + try testing.expectEqualStrings("abcdefghijklmnopqrstuvwyxz", list.items); + } + + // wtf8ToWtf16LeArrayList + { + var list = std.ArrayList(u16).init(testing.allocator); + defer list.deinit(); + + const init_slice = utf8ToUtf16LeStringLiteral("abcdefg"); + try list.ensureTotalCapacityPrecise(init_slice.len); + list.appendSliceAssumeCapacity(init_slice); + + try wtf8ToWtf16LeArrayList(&list, "hijklmnopqrstuvwyxz"); + + try testing.expectEqualSlices(u16, utf8ToUtf16LeStringLiteral("abcdefghijklmnopqrstuvwyxz"), list.items); + } + + // wtf16LeToWtf8ArrayList + { + var list = std.ArrayList(u8).init(testing.allocator); + defer list.deinit(); + + const init_slice = "abcdefg"; + try list.ensureTotalCapacityPrecise(init_slice.len); + list.appendSliceAssumeCapacity(init_slice); + + try wtf16LeToWtf8ArrayList(&list, utf8ToUtf16LeStringLiteral("hijklmnopqrstuvwyxz")); + + try testing.expectEqualStrings("abcdefghijklmnopqrstuvwyxz", list.items); + } +} + /// Converts a UTF-8 string literal into a UTF-16LE string literal. pub fn utf8ToUtf16LeStringLiteral(comptime utf8: []const u8) *const [calcUtf16LeLen(utf8) catch |err| @compileError(err):0]u16 { return comptime blk: { @@ -1685,7 +1743,7 @@ pub const Wtf8Iterator = struct { }; pub fn wtf16LeToWtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) mem.Allocator.Error!void { - try result.ensureTotalCapacityPrecise(utf16le.len); + try result.ensureUnusedCapacity(utf16le.len); return utf16LeToUtf8ArrayListImpl(result, utf16le, .can_encode_surrogate_half); } @@ -1714,7 +1772,7 @@ pub fn wtf16LeToWtf8(wtf8: []u8, wtf16le: []const u16) usize { } pub fn wtf8ToWtf16LeArrayList(result: *std.ArrayList(u16), wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }!void { - try result.ensureTotalCapacityPrecise(wtf8.len); + try result.ensureUnusedCapacity(wtf8.len); return utf8ToUtf16LeArrayListImpl(result, wtf8, .can_encode_surrogate_half); } diff --git a/test/standalone/build.zig.zon b/test/standalone/build.zig.zon index a307c0a528..7ddde1611a 100644 --- a/test/standalone/build.zig.zon +++ b/test/standalone/build.zig.zon @@ -107,6 +107,9 @@ .windows_argv = .{ .path = "windows_argv", }, + .windows_bat_args = .{ + .path = "windows_bat_args", + }, .self_exe_symlink = .{ .path = "self_exe_symlink", }, diff --git a/test/standalone/windows_bat_args/build.zig b/test/standalone/windows_bat_args/build.zig new file mode 100644 index 0000000000..1d0a82b920 --- /dev/null +++ b/test/standalone/windows_bat_args/build.zig @@ -0,0 +1,58 @@ +const std = @import("std"); +const builtin = @import("builtin"); + +pub fn build(b: *std.Build) !void { + const test_step = b.step("test", "Test it"); + b.default_step = test_step; + + const optimize: std.builtin.OptimizeMode = .Debug; + const target = b.host; + + if (builtin.os.tag != .windows) return; + + const echo_args = b.addExecutable(.{ + .name = "echo-args", + .root_source_file = b.path("echo-args.zig"), + .optimize = optimize, + .target = target, + }); + + const test_exe = b.addExecutable(.{ + .name = "test", + .root_source_file = b.path("test.zig"), + .optimize = optimize, + .target = target, + }); + + const run = b.addRunArtifact(test_exe); + run.addArtifactArg(echo_args); + run.expectExitCode(0); + run.skip_foreign_checks = true; + + test_step.dependOn(&run.step); + + const fuzz = b.addExecutable(.{ + .name = "fuzz", + .root_source_file = b.path("fuzz.zig"), + .optimize = optimize, + .target = target, + }); + + const fuzz_max_iterations = b.option(u64, "iterations", "The max fuzz iterations (default: 100)") orelse 100; + const fuzz_iterations_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_max_iterations}) catch @panic("oom"); + + const fuzz_seed = b.option(u64, "seed", "Seed to use for the PRNG (default: random)") orelse seed: { + var buf: [8]u8 = undefined; + try std.posix.getrandom(&buf); + break :seed std.mem.readInt(u64, &buf, builtin.cpu.arch.endian()); + }; + const fuzz_seed_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_seed}) catch @panic("oom"); + + const fuzz_run = b.addRunArtifact(fuzz); + fuzz_run.addArtifactArg(echo_args); + fuzz_run.addArgs(&.{ fuzz_iterations_arg, fuzz_seed_arg }); + fuzz_run.expectExitCode(0); + fuzz_run.skip_foreign_checks = true; + + test_step.dependOn(&fuzz_run.step); +} diff --git a/test/standalone/windows_bat_args/echo-args.zig b/test/standalone/windows_bat_args/echo-args.zig new file mode 100644 index 0000000000..2552045aed --- /dev/null +++ b/test/standalone/windows_bat_args/echo-args.zig @@ -0,0 +1,14 @@ +const std = @import("std"); + +pub fn main() !void { + var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + const stdout = std.io.getStdOut().writer(); + var args = try std.process.argsAlloc(arena); + for (args[1..], 1..) |arg, i| { + try stdout.writeAll(arg); + if (i != args.len - 1) try stdout.writeByte('\x00'); + } +} diff --git a/test/standalone/windows_bat_args/fuzz.zig b/test/standalone/windows_bat_args/fuzz.zig new file mode 100644 index 0000000000..6908f66a06 --- /dev/null +++ b/test/standalone/windows_bat_args/fuzz.zig @@ -0,0 +1,160 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const Allocator = std.mem.Allocator; + +pub fn main() anyerror!void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer if (gpa.deinit() == .leak) @panic("found memory leaks"); + const allocator = gpa.allocator(); + + var it = try std.process.argsWithAllocator(allocator); + defer it.deinit(); + _ = it.next() orelse unreachable; // skip binary name + const child_exe_path = it.next() orelse unreachable; + + const iterations: u64 = iterations: { + const arg = it.next() orelse "0"; + break :iterations try std.fmt.parseUnsigned(u64, arg, 10); + }; + + var rand_seed = false; + const seed: u64 = seed: { + const seed_arg = it.next() orelse { + rand_seed = true; + var buf: [8]u8 = undefined; + try std.posix.getrandom(&buf); + break :seed std.mem.readInt(u64, &buf, builtin.cpu.arch.endian()); + }; + break :seed try std.fmt.parseUnsigned(u64, seed_arg, 10); + }; + var random = std.rand.DefaultPrng.init(seed); + const rand = random.random(); + + // If the seed was not given via the CLI, then output the + // randomly chosen seed so that this run can be reproduced + if (rand_seed) { + std.debug.print("rand seed: {}\n", .{seed}); + } + + var tmp = std.testing.tmpDir(.{}); + defer tmp.cleanup(); + + try tmp.dir.setAsCwd(); + defer tmp.parent_dir.setAsCwd() catch {}; + + var buf = try std.ArrayList(u8).initCapacity(allocator, 128); + defer buf.deinit(); + try buf.appendSlice("@echo off\n"); + try buf.append('"'); + try buf.appendSlice(child_exe_path); + try buf.append('"'); + const preamble_len = buf.items.len; + + try buf.appendSlice(" %*"); + try tmp.dir.writeFile("args1.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + try buf.appendSlice(" %1 %2 %3 %4 %5 %6 %7 %8 %9"); + try tmp.dir.writeFile("args2.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + try buf.appendSlice(" \"%~1\" \"%~2\" \"%~3\" \"%~4\" \"%~5\" \"%~6\" \"%~7\" \"%~8\" \"%~9\""); + try tmp.dir.writeFile("args3.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + var i: u64 = 0; + while (iterations == 0 or i < iterations) { + const rand_arg = try randomArg(allocator, rand); + defer allocator.free(rand_arg); + + try testExec(allocator, &.{rand_arg}, null); + + i += 1; + } +} + +fn testExec(allocator: std.mem.Allocator, args: []const []const u8, env: ?*std.process.EnvMap) !void { + try testExecBat(allocator, "args1.bat", args, env); + try testExecBat(allocator, "args2.bat", args, env); + try testExecBat(allocator, "args3.bat", args, env); +} + +fn testExecBat(allocator: std.mem.Allocator, bat: []const u8, args: []const []const u8, env: ?*std.process.EnvMap) !void { + var argv = try std.ArrayList([]const u8).initCapacity(allocator, 1 + args.len); + defer argv.deinit(); + argv.appendAssumeCapacity(bat); + argv.appendSliceAssumeCapacity(args); + + const can_have_trailing_empty_args = std.mem.eql(u8, bat, "args3.bat"); + + const result = try std.ChildProcess.run(.{ + .allocator = allocator, + .env_map = env, + .argv = argv.items, + }); + defer allocator.free(result.stdout); + defer allocator.free(result.stderr); + + try std.testing.expectEqualStrings("", result.stderr); + var it = std.mem.splitScalar(u8, result.stdout, '\x00'); + var i: usize = 0; + while (it.next()) |actual_arg| { + if (i >= args.len and can_have_trailing_empty_args) { + try std.testing.expectEqualStrings("", actual_arg); + continue; + } + const expected_arg = args[i]; + try std.testing.expectEqualSlices(u8, expected_arg, actual_arg); + i += 1; + } +} + +fn randomArg(allocator: Allocator, rand: std.rand.Random) ![]const u8 { + const Choice = enum { + backslash, + quote, + space, + control, + printable, + surrogate_half, + non_ascii, + }; + + const choices = rand.uintAtMostBiased(u16, 256); + var buf = try std.ArrayList(u8).initCapacity(allocator, choices); + errdefer buf.deinit(); + + var last_codepoint: u21 = 0; + for (0..choices) |_| { + const choice = rand.enumValue(Choice); + const codepoint: u21 = switch (choice) { + .backslash => '\\', + .quote => '"', + .space => ' ', + .control => switch (rand.uintAtMostBiased(u8, 0x21)) { + // NUL/CR/LF can't roundtrip + '\x00', '\r', '\n' => ' ', + 0x21 => '\x7F', + else => |b| b, + }, + .printable => '!' + rand.uintAtMostBiased(u8, '~' - '!'), + .surrogate_half => rand.intRangeAtMostBiased(u16, 0xD800, 0xDFFF), + .non_ascii => rand.intRangeAtMostBiased(u21, 0x80, 0x10FFFF), + }; + // Ensure that we always return well-formed WTF-8. + // Instead of concatenating to ensure well-formed WTF-8, + // we just skip encoding the low surrogate. + if (std.unicode.isSurrogateCodepoint(last_codepoint) and std.unicode.isSurrogateCodepoint(codepoint)) { + if (std.unicode.utf16IsHighSurrogate(@intCast(last_codepoint)) and std.unicode.utf16IsLowSurrogate(@intCast(codepoint))) { + continue; + } + } + try buf.ensureUnusedCapacity(4); + const unused_slice = buf.unusedCapacitySlice(); + const len = std.unicode.wtf8Encode(codepoint, unused_slice) catch unreachable; + buf.items.len += len; + last_codepoint = codepoint; + } + + return buf.toOwnedSlice(); +} diff --git a/test/standalone/windows_bat_args/test.zig b/test/standalone/windows_bat_args/test.zig new file mode 100644 index 0000000000..8dc27b14c9 --- /dev/null +++ b/test/standalone/windows_bat_args/test.zig @@ -0,0 +1,132 @@ +const std = @import("std"); + +pub fn main() anyerror!void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer if (gpa.deinit() == .leak) @panic("found memory leaks"); + const allocator = gpa.allocator(); + + var it = try std.process.argsWithAllocator(allocator); + defer it.deinit(); + _ = it.next() orelse unreachable; // skip binary name + const child_exe_path = it.next() orelse unreachable; + + var tmp = std.testing.tmpDir(.{}); + defer tmp.cleanup(); + + try tmp.dir.setAsCwd(); + defer tmp.parent_dir.setAsCwd() catch {}; + + var buf = try std.ArrayList(u8).initCapacity(allocator, 128); + defer buf.deinit(); + try buf.appendSlice("@echo off\n"); + try buf.append('"'); + try buf.appendSlice(child_exe_path); + try buf.append('"'); + const preamble_len = buf.items.len; + + try buf.appendSlice(" %*"); + try tmp.dir.writeFile("args1.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + try buf.appendSlice(" %1 %2 %3 %4 %5 %6 %7 %8 %9"); + try tmp.dir.writeFile("args2.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + try buf.appendSlice(" \"%~1\" \"%~2\" \"%~3\" \"%~4\" \"%~5\" \"%~6\" \"%~7\" \"%~8\" \"%~9\""); + try tmp.dir.writeFile("args3.bat", buf.items); + buf.shrinkRetainingCapacity(preamble_len); + + // Test cases are from https://github.com/rust-lang/rust/blob/master/tests/ui/std/windows-bat-args.rs + try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\x00"}); + try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\n"}); + try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\r"}); + try testExec(allocator, &.{ "a", "b" }, null); + try testExec(allocator, &.{ "c is for cat", "d is for dog" }, null); + try testExec(allocator, &.{ "\"", " \"" }, null); + try testExec(allocator, &.{ "\\", "\\" }, null); + try testExec(allocator, &.{">file.txt"}, null); + try testExec(allocator, &.{"whoami.exe"}, null); + try testExec(allocator, &.{"&a.exe"}, null); + try testExec(allocator, &.{"&echo hello "}, null); + try testExec(allocator, &.{ "&echo hello", "&whoami", ">file.txt" }, null); + try testExec(allocator, &.{"!TMP!"}, null); + try testExec(allocator, &.{"key=value"}, null); + try testExec(allocator, &.{"\"key=value\""}, null); + try testExec(allocator, &.{"key = value"}, null); + try testExec(allocator, &.{"key=[\"value\"]"}, null); + try testExec(allocator, &.{ "", "a=b" }, null); + try testExec(allocator, &.{"key=\"foo bar\""}, null); + try testExec(allocator, &.{"key=[\"my_value]"}, null); + try testExec(allocator, &.{"key=[\"my_value\",\"other-value\"]"}, null); + try testExec(allocator, &.{"key\\=value"}, null); + try testExec(allocator, &.{"key=\"&whoami\""}, null); + try testExec(allocator, &.{"key=\"value\"=5"}, null); + try testExec(allocator, &.{"key=[\">file.txt\"]"}, null); + try testExec(allocator, &.{"%hello"}, null); + try testExec(allocator, &.{"%PATH%"}, null); + try testExec(allocator, &.{"%%cd:~,%"}, null); + try testExec(allocator, &.{"%PATH%PATH%"}, null); + try testExec(allocator, &.{"\">file.txt"}, null); + try testExec(allocator, &.{"abc\"&echo hello"}, null); + try testExec(allocator, &.{"123\">file.txt"}, null); + try testExec(allocator, &.{"\"&echo hello&whoami.exe"}, null); + try testExec(allocator, &.{ "\"hello^\"world\"", "hello &echo oh no >file.txt" }, null); + try testExec(allocator, &.{"&whoami.exe"}, null); + + var env = env: { + var env = try std.process.getEnvMap(allocator); + errdefer env.deinit(); + // No escaping + try env.put("FOO", "123"); + // Some possible escaping of %FOO% that could be expanded + // when escaping cmd.exe meta characters with ^ + try env.put("FOO^", "123"); // only escaping % + try env.put("^F^O^O^", "123"); // escaping every char + break :env env; + }; + defer env.deinit(); + try testExec(allocator, &.{"%FOO%"}, &env); + + // Ensure that none of the `>file.txt`s have caused file.txt to be created + try std.testing.expectError(error.FileNotFound, tmp.dir.access("file.txt", .{})); +} + +fn testExecError(err: anyerror, allocator: std.mem.Allocator, args: []const []const u8) !void { + return std.testing.expectError(err, testExec(allocator, args, null)); +} + +fn testExec(allocator: std.mem.Allocator, args: []const []const u8, env: ?*std.process.EnvMap) !void { + try testExecBat(allocator, "args1.bat", args, env); + try testExecBat(allocator, "args2.bat", args, env); + try testExecBat(allocator, "args3.bat", args, env); +} + +fn testExecBat(allocator: std.mem.Allocator, bat: []const u8, args: []const []const u8, env: ?*std.process.EnvMap) !void { + var argv = try std.ArrayList([]const u8).initCapacity(allocator, 1 + args.len); + defer argv.deinit(); + argv.appendAssumeCapacity(bat); + argv.appendSliceAssumeCapacity(args); + + const can_have_trailing_empty_args = std.mem.eql(u8, bat, "args3.bat"); + + const result = try std.ChildProcess.run(.{ + .allocator = allocator, + .env_map = env, + .argv = argv.items, + }); + defer allocator.free(result.stdout); + defer allocator.free(result.stderr); + + try std.testing.expectEqualStrings("", result.stderr); + var it = std.mem.splitScalar(u8, result.stdout, '\x00'); + var i: usize = 0; + while (it.next()) |actual_arg| { + if (i >= args.len and can_have_trailing_empty_args) { + try std.testing.expectEqualStrings("", actual_arg); + continue; + } + const expected_arg = args[i]; + try std.testing.expectEqualStrings(expected_arg, actual_arg); + i += 1; + } +}