Skip to content

std.process.Child: Mitigate arbitrary command execution vulnerability on Windows (BatBadBut) #19698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 276 additions & 12 deletions lib/std/child_process.zig

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions lib/std/os/windows/kernel32.zig
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ pub extern "kernel32" fn GetSystemInfo(lpSystemInfo: *SYSTEM_INFO) callconv(WINA
pub extern "kernel32" fn GetSystemTimeAsFileTime(*FILETIME) callconv(WINAPI) void;
// Declared with WINAPI like every other kernel32 binding in this group. An
// `extern fn` without an explicit `callconv` uses the C calling convention,
// which is not the Win32 convention on every target (notably 32-bit x86).
pub extern "kernel32" fn IsProcessorFeaturePresent(ProcessorFeature: DWORD) callconv(WINAPI) BOOL;

pub extern "kernel32" fn GetSystemDirectoryW(lpBuffer: LPWSTR, uSize: UINT) callconv(WINAPI) UINT;

pub extern "kernel32" fn HeapCreate(flOptions: DWORD, dwInitialSize: SIZE_T, dwMaximumSize: SIZE_T) callconv(WINAPI) ?HANDLE;
pub extern "kernel32" fn HeapDestroy(hHeap: HANDLE) callconv(WINAPI) BOOL;
pub extern "kernel32" fn HeapReAlloc(hHeap: HANDLE, dwFlags: DWORD, lpMem: *anyopaque, dwBytes: SIZE_T) callconv(WINAPI) ?*anyopaque;
Expand Down
70 changes: 64 additions & 6 deletions lib/std/unicode.zig
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ fn utf16LeToUtf8ArrayListImpl(
.cannot_encode_surrogate_half => Utf16LeToUtf8AllocError,
.can_encode_surrogate_half => mem.Allocator.Error,
})!void {
assert(result.capacity >= utf16le.len);
assert(result.unusedCapacitySlice().len >= utf16le.len);

var remaining = utf16le;
vectorized: {
Expand Down Expand Up @@ -979,7 +979,7 @@ fn utf16LeToUtf8ArrayListImpl(
pub const Utf16LeToUtf8AllocError = mem.Allocator.Error || Utf16LeToUtf8Error;

/// Converts `utf16le` to UTF-8 and appends the bytes to `result`; items already
/// in the list are kept.
/// Fails with an allocation error or a `Utf16LeToUtf8Error` (see
/// `Utf16LeToUtf8AllocError`).
pub fn utf16LeToUtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) Utf16LeToUtf8AllocError!void {
    // The implementation asserts on the list's *unused* capacity, so reserve
    // room on top of the current length. Reserving by total capacity would be
    // insufficient for a list that already holds items.
    try result.ensureUnusedCapacity(utf16le.len);
    return utf16LeToUtf8ArrayListImpl(result, utf16le, .cannot_encode_surrogate_half);
}

Expand Down Expand Up @@ -1138,7 +1138,7 @@ test utf16LeToUtf8 {
}

fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, comptime surrogates: Surrogates) !void {
assert(result.capacity >= utf8.len);
assert(result.unusedCapacitySlice().len >= utf8.len);

var remaining = utf8;
vectorized: {
Expand Down Expand Up @@ -1176,7 +1176,7 @@ fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, com
}

/// Converts `utf8` to UTF-16LE and appends the code units to `result`; items
/// already in the list are kept.
/// Fails with `error.InvalidUtf8` or `error.OutOfMemory`.
pub fn utf8ToUtf16LeArrayList(result: *std.ArrayList(u16), utf8: []const u8) error{ InvalidUtf8, OutOfMemory }!void {
    // The implementation asserts on the list's *unused* capacity, so reserve
    // room on top of the current length. Reserving by total capacity would be
    // insufficient for a list that already holds items.
    try result.ensureUnusedCapacity(utf8.len);
    return utf8ToUtf16LeArrayListImpl(result, utf8, .cannot_encode_surrogate_half);
}

Expand Down Expand Up @@ -1351,6 +1351,64 @@ test utf8ToUtf16LeAllocZ {
}
}

test "ArrayList functions on a re-used list" {
    // Each case starts from a list whose capacity was requested to be exactly
    // its initial contents, then appends a conversion result to it and checks
    // that the original items are still in front of the converted ones.
    const prefix = "abcdefg";
    const suffix = "hijklmnopqrstuvwyxz";
    const joined = prefix ++ suffix;

    const prefix_w = utf8ToUtf16LeStringLiteral(prefix);
    const suffix_w = utf8ToUtf16LeStringLiteral(suffix);
    const joined_w = utf8ToUtf16LeStringLiteral(joined);

    // utf8ToUtf16LeArrayList
    {
        var reused = std.ArrayList(u16).init(testing.allocator);
        defer reused.deinit();

        try reused.ensureTotalCapacityPrecise(prefix_w.len);
        reused.appendSliceAssumeCapacity(prefix_w);

        try utf8ToUtf16LeArrayList(&reused, suffix);

        try testing.expectEqualSlices(u16, joined_w, reused.items);
    }

    // utf16LeToUtf8ArrayList
    {
        var reused = std.ArrayList(u8).init(testing.allocator);
        defer reused.deinit();

        try reused.ensureTotalCapacityPrecise(prefix.len);
        reused.appendSliceAssumeCapacity(prefix);

        try utf16LeToUtf8ArrayList(&reused, suffix_w);

        try testing.expectEqualStrings(joined, reused.items);
    }

    // wtf8ToWtf16LeArrayList
    {
        var reused = std.ArrayList(u16).init(testing.allocator);
        defer reused.deinit();

        try reused.ensureTotalCapacityPrecise(prefix_w.len);
        reused.appendSliceAssumeCapacity(prefix_w);

        try wtf8ToWtf16LeArrayList(&reused, suffix);

        try testing.expectEqualSlices(u16, joined_w, reused.items);
    }

    // wtf16LeToWtf8ArrayList
    {
        var reused = std.ArrayList(u8).init(testing.allocator);
        defer reused.deinit();

        try reused.ensureTotalCapacityPrecise(prefix.len);
        reused.appendSliceAssumeCapacity(prefix);

        try wtf16LeToWtf8ArrayList(&reused, suffix_w);

        try testing.expectEqualStrings(joined, reused.items);
    }
}

/// Converts a UTF-8 string literal into a UTF-16LE string literal.
pub fn utf8ToUtf16LeStringLiteral(comptime utf8: []const u8) *const [calcUtf16LeLen(utf8) catch |err| @compileError(err):0]u16 {
return comptime blk: {
Expand Down Expand Up @@ -1685,7 +1743,7 @@ pub const Wtf8Iterator = struct {
};

/// Converts `utf16le` to WTF-8 and appends the bytes to `result`; items already
/// in the list are kept.
/// Only allocation can fail.
pub fn wtf16LeToWtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) mem.Allocator.Error!void {
    // The implementation asserts on the list's *unused* capacity, so reserve
    // room on top of the current length. Reserving by total capacity would be
    // insufficient for a list that already holds items.
    try result.ensureUnusedCapacity(utf16le.len);
    return utf16LeToUtf8ArrayListImpl(result, utf16le, .can_encode_surrogate_half);
}

Expand Down Expand Up @@ -1714,7 +1772,7 @@ pub fn wtf16LeToWtf8(wtf8: []u8, wtf16le: []const u16) usize {
}

/// Converts `wtf8` to WTF-16LE and appends the code units to `result`; items
/// already in the list are kept.
/// Fails with `error.InvalidWtf8` or `error.OutOfMemory`.
pub fn wtf8ToWtf16LeArrayList(result: *std.ArrayList(u16), wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }!void {
    // The implementation asserts on the list's *unused* capacity, so reserve
    // room on top of the current length. Reserving by total capacity would be
    // insufficient for a list that already holds items.
    try result.ensureUnusedCapacity(wtf8.len);
    return utf8ToUtf16LeArrayListImpl(result, wtf8, .can_encode_surrogate_half);
}

Expand Down
3 changes: 3 additions & 0 deletions test/standalone/build.zig.zon
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@
.windows_argv = .{
.path = "windows_argv",
},
.windows_bat_args = .{
.path = "windows_bat_args",
},
.self_exe_symlink = .{
.path = "self_exe_symlink",
},
Expand Down
58 changes: 58 additions & 0 deletions test/standalone/windows_bat_args/build.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
const std = @import("std");
const builtin = @import("builtin");

/// Build script for this standalone test.
///
/// Always creates a `test` step and makes it the default. When the build
/// runs on Windows it additionally builds `echo-args`, `test` and `fuzz`, and
/// wires one run of `test` and one run of `fuzz` into the `test` step; both
/// receive the `echo-args` artifact as their first argument and must exit
/// with code 0. On any other OS the `test` step is left without dependencies.
///
/// Build options:
///  * `iterations` - max fuzz iterations, defaults to 100.
///  * `seed` - PRNG seed for the fuzzer, defaults to a random value.
pub fn build(b: *std.Build) !void {
    const test_step = b.step("test", "Test it");
    b.default_step = test_step;

    // Nothing below is registered unless the build itself runs on Windows.
    if (builtin.os.tag != .windows) return;

    const host_target = b.host;
    const mode: std.builtin.OptimizeMode = .Debug;

    // Helper executable handed to both runners as their first argument.
    const echo_exe = b.addExecutable(.{
        .name = "echo-args",
        .root_source_file = b.path("echo-args.zig"),
        .optimize = mode,
        .target = host_target,
    });

    // Fixed test cases.
    const fixed_exe = b.addExecutable(.{
        .name = "test",
        .root_source_file = b.path("test.zig"),
        .optimize = mode,
        .target = host_target,
    });

    const fixed_run = b.addRunArtifact(fixed_exe);
    fixed_run.addArtifactArg(echo_exe);
    fixed_run.expectExitCode(0);
    fixed_run.skip_foreign_checks = true;

    test_step.dependOn(&fixed_run.step);

    // Randomized test cases.
    const fuzz_exe = b.addExecutable(.{
        .name = "fuzz",
        .root_source_file = b.path("fuzz.zig"),
        .optimize = mode,
        .target = host_target,
    });

    const max_iterations = b.option(u64, "iterations", "The max fuzz iterations (default: 100)") orelse 100;
    const iterations_arg = std.fmt.allocPrint(b.allocator, "{}", .{max_iterations}) catch @panic("oom");

    // Use the seed from the command line if there is one, otherwise draw a
    // fresh one from the OS.
    const seed_value = if (b.option(u64, "seed", "Seed to use for the PRNG (default: random)")) |given|
        given
    else
        random_seed: {
            var bytes: [8]u8 = undefined;
            try std.posix.getrandom(&bytes);
            break :random_seed std.mem.readInt(u64, &bytes, builtin.cpu.arch.endian());
        };
    const seed_arg = std.fmt.allocPrint(b.allocator, "{}", .{seed_value}) catch @panic("oom");

    const fuzz_run = b.addRunArtifact(fuzz_exe);
    fuzz_run.addArtifactArg(echo_exe);
    fuzz_run.addArgs(&.{ iterations_arg, seed_arg });
    fuzz_run.expectExitCode(0);
    fuzz_run.skip_foreign_checks = true;

    test_step.dependOn(&fuzz_run.step);
}
14 changes: 14 additions & 0 deletions test/standalone/windows_bat_args/echo-args.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
const std = @import("std");

/// Writes every command-line argument after the program name to stdout.
/// Arguments are separated (not terminated) by a single NUL byte, so the
/// output for one argument contains no NUL at all.
pub fn main() !void {
    // All allocations live until exit; the arena releases them in one go.
    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    const stdout = std.io.getStdOut().writer();
    // The slice is only read, never reassigned, so it is bound as `const`.
    const args = try std.process.argsAlloc(arena);
    // `i` is the index into `args` (it starts at 1 because args[0] is skipped).
    for (args[1..], 1..) |arg, i| {
        try stdout.writeAll(arg);
        // No separator after the last argument.
        if (i != args.len - 1) try stdout.writeByte('\x00');
    }
}
160 changes: 160 additions & 0 deletions test/standalone/windows_bat_args/fuzz.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
const std = @import("std");
const builtin = @import("builtin");
const Allocator = std.mem.Allocator;

/// Fuzzer entry point.
///
/// Command line: `<child_exe_path> [iterations] [seed]`.
///  * `child_exe_path` - executable that the generated batch scripts invoke.
///  * `iterations` - number of random arguments to try; `0` (also the value
///    used when the argument is absent) means run until a failure occurs.
///  * `seed` - PRNG seed; when absent a random seed is drawn and printed.
///
/// Three batch scripts are written into a temporary directory, which is made
/// the current working directory for the duration of the run. Each one
/// forwards its arguments to the child executable in a different way
/// (`%*`, `%1`..`%9`, and `"%~1"`..`"%~9"`).
pub fn main() anyerror!void {
    var gpa_state = std.heap.GeneralPurposeAllocator(.{}){};
    defer if (gpa_state.deinit() == .leak) @panic("found memory leaks");
    const allocator = gpa_state.allocator();

    var cli = try std.process.argsWithAllocator(allocator);
    defer cli.deinit();
    _ = cli.next() orelse unreachable; // skip binary name
    const child_exe_path = cli.next() orelse unreachable;

    const iterations_text = cli.next() orelse "0";
    const iterations: u64 = try std.fmt.parseUnsigned(u64, iterations_text, 10);

    const seed_text = cli.next();
    const seed: u64 = if (seed_text) |text|
        try std.fmt.parseUnsigned(u64, text, 10)
    else
        fresh: {
            var bytes: [8]u8 = undefined;
            try std.posix.getrandom(&bytes);
            break :fresh std.mem.readInt(u64, &bytes, builtin.cpu.arch.endian());
        };
    var prng = std.rand.DefaultPrng.init(seed);
    const rand = prng.random();

    // A seed that was not supplied on the command line is printed so that
    // the run can be reproduced.
    if (seed_text == null) {
        std.debug.print("rand seed: {}\n", .{seed});
    }

    var tmp = std.testing.tmpDir(.{});
    defer tmp.cleanup();

    try tmp.dir.setAsCwd();
    defer tmp.parent_dir.setAsCwd() catch {};

    // All scripts share the same preamble: echo off, then the quoted path of
    // the child executable. Only the argument-forwarding suffix differs.
    var script = try std.ArrayList(u8).initCapacity(allocator, 128);
    defer script.deinit();
    try script.appendSlice("@echo off\n");
    try script.append('"');
    try script.appendSlice(child_exe_path);
    try script.append('"');
    const preamble_len = script.items.len;

    const variants = [_]struct { file_name: []const u8, forwarding: []const u8 }{
        .{ .file_name = "args1.bat", .forwarding = " %*" },
        .{ .file_name = "args2.bat", .forwarding = " %1 %2 %3 %4 %5 %6 %7 %8 %9" },
        .{ .file_name = "args3.bat", .forwarding = " \"%~1\" \"%~2\" \"%~3\" \"%~4\" \"%~5\" \"%~6\" \"%~7\" \"%~8\" \"%~9\"" },
    };
    for (variants) |variant| {
        try script.appendSlice(variant.forwarding);
        try tmp.dir.writeFile(variant.file_name, script.items);
        // Drop the suffix again so the next variant starts from the preamble.
        script.shrinkRetainingCapacity(preamble_len);
    }

    var completed: u64 = 0;
    while (iterations == 0 or completed < iterations) : (completed += 1) {
        const rand_arg = try randomArg(allocator, rand);
        defer allocator.free(rand_arg);

        try testExec(allocator, &.{rand_arg}, null);
    }
}

/// Runs `args` through each of the three batch scripts in turn
/// (`args1.bat`, `args2.bat`, `args3.bat`), stopping at the first failure.
/// `env` is forwarded unchanged to every run.
fn testExec(allocator: std.mem.Allocator, args: []const []const u8, env: ?*std.process.EnvMap) !void {
    const scripts = [_][]const u8{ "args1.bat", "args2.bat", "args3.bat" };
    for (scripts) |script| {
        try testExecBat(allocator, script, args, env);
    }
}

/// Spawns `bat` with `args` as its arguments and checks what the child wrote.
///
/// The child's stderr must be empty. Its stdout is split on NUL bytes and each
/// field is compared, in order, with the corresponding entry of `args`.
/// For `args3.bat` only, fields beyond `args.len` are accepted as long as they
/// are empty; for any other script such an extra field is reported as
/// `error.TestUnexpectedResult`.
///
/// `env`, when non-null, is used as the child's environment map.
/// Only the fields the child actually emitted are checked; this function does
/// not verify that the child produced at least `args.len` fields.
fn testExecBat(allocator: std.mem.Allocator, bat: []const u8, args: []const []const u8, env: ?*std.process.EnvMap) !void {
    var argv = try std.ArrayList([]const u8).initCapacity(allocator, 1 + args.len);
    defer argv.deinit();
    argv.appendAssumeCapacity(bat);
    argv.appendSliceAssumeCapacity(args);

    const can_have_trailing_empty_args = std.mem.eql(u8, bat, "args3.bat");

    const result = try std.ChildProcess.run(.{
        .allocator = allocator,
        .env_map = env,
        .argv = argv.items,
    });
    defer allocator.free(result.stdout);
    defer allocator.free(result.stderr);

    try std.testing.expectEqualStrings("", result.stderr);
    var it = std.mem.splitScalar(u8, result.stdout, '\x00');
    var i: usize = 0;
    while (it.next()) |actual_arg| {
        if (i >= args.len) {
            if (can_have_trailing_empty_args) {
                try std.testing.expectEqualStrings("", actual_arg);
                continue;
            }
            // The child produced more fields than arguments were passed.
            // Report it as a test failure instead of indexing past `args`.
            std.debug.print("{s}: unexpected extra argument at index {}: {any}\n", .{ bat, i, actual_arg });
            return error.TestUnexpectedResult;
        }
        const expected_arg = args[i];
        try std.testing.expectEqualSlices(u8, expected_arg, actual_arg);
        i += 1;
    }
}

/// Builds a random argument out of up to 256 randomly chosen codepoints and
/// returns it encoded as well-formed WTF-8. Caller owns the returned slice.
///
/// Each codepoint is drawn from one of seven equally likely categories:
/// backslash, double quote, space, control character, printable ASCII,
/// surrogate half, or an arbitrary codepoint in 0x80...0x10FFFF.
/// NUL, CR and LF are never produced (a space is substituted).
fn randomArg(allocator: Allocator, rand: std.rand.Random) ![]const u8 {
    const Category = enum {
        backslash,
        quote,
        space,
        control,
        printable,
        surrogate_half,
        non_ascii,
    };

    const count = rand.uintAtMostBiased(u16, 256);
    var encoded = try std.ArrayList(u8).initCapacity(allocator, count);
    errdefer encoded.deinit();

    var previous: u21 = 0;
    var remaining = count;
    while (remaining > 0) : (remaining -= 1) {
        const current: u21 = switch (rand.enumValue(Category)) {
            .backslash => '\\',
            .quote => '"',
            .space => ' ',
            .control => control: {
                const byte = rand.uintAtMostBiased(u8, 0x21);
                // 0x21 stands in for DEL, the one control character outside 0x00...0x20.
                if (byte == 0x21) break :control '\x7F';
                // NUL/CR/LF can't roundtrip
                if (byte == '\x00' or byte == '\r' or byte == '\n') break :control ' ';
                break :control @as(u21, byte);
            },
            .printable => '!' + rand.uintAtMostBiased(u8, '~' - '!'),
            .surrogate_half => rand.intRangeAtMostBiased(u16, 0xD800, 0xDFFF),
            .non_ascii => rand.intRangeAtMostBiased(u21, 0x80, 0x10FFFF),
        };

        // A high surrogate directly followed by a low surrogate would not be
        // well-formed WTF-8. Rather than merging the pair into one codepoint,
        // the low surrogate is simply dropped. `and` short-circuits, so the
        // casts to u16 only happen for values known to be surrogates.
        const is_surrogate_pair = std.unicode.isSurrogateCodepoint(previous) and
            std.unicode.isSurrogateCodepoint(current) and
            std.unicode.utf16IsHighSurrogate(@intCast(previous)) and
            std.unicode.utf16IsLowSurrogate(@intCast(current));
        if (is_surrogate_pair) continue;

        // Four bytes is enough room for any single encoded codepoint.
        try encoded.ensureUnusedCapacity(4);
        const written = std.unicode.wtf8Encode(current, encoded.unusedCapacitySlice()) catch unreachable;
        encoded.items.len += written;
        previous = current;
    }

    return encoded.toOwnedSlice();
}
Loading