From 836f9fceab03c7de56eba7a9c2e810206e7e8469 Mon Sep 17 00:00:00 2001
From: Luuk de Gram
Date: Mon, 3 Jul 2023 20:08:13 +0200
Subject: [PATCH 1/3] llvm: add safety-check for Wasm memcpy

When the bulk-memory feature is enabled, LLVM lowers the `memcpy`
instruction to WebAssembly's `memory.copy` instruction. That instruction
traps when the destination or source pointer is out-of-bounds, regardless
of the length. By Zig's semantics it is valid for a slice's pointer to be
invalid when its length is 0. To prevent runtimes from trapping, we add a
safety check for slices so that a memcpy is only emitted when the length
is greater than 0.
---
 src/codegen/llvm.zig | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index 24ff706711..597dfc8e4f 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -8634,6 +8634,41 @@ pub const FuncGen = struct {
         const len = self.sliceOrArrayLenInBytes(dest_slice, dest_ptr_ty);
         const dest_ptr = self.sliceOrArrayPtr(dest_slice, dest_ptr_ty);
         const is_volatile = src_ptr_ty.isVolatilePtr(mod) or dest_ptr_ty.isVolatilePtr(mod);
+
+        // When bulk-memory is enabled, this will be lowered to WebAssembly's memory.copy instruction.
+        // This instruction will trap on an invalid address, regardless of the length.
+        // For this reason we must add a safety-check for 0-sized slices as its pointer field can be undefined.
+        // We only have to do this for slices as arrays will have a valid pointer.
+        if (o.target.isWasm() and
+            std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory) and
+            (src_ptr_ty.isSlice(mod) or dest_ptr_ty.isSlice(mod)))
+        {
+            const parent_block = self.context.createBasicBlock("Block");
+
+            const llvm_usize_ty = self.context.intType(o.target.ptrBitWidth());
+            const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
+            const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
+            const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
+            _ = self.builder.buildCondBr(cond, then_block, else_block);
+
+            self.builder.positionBuilderAtEnd(then_block);
+            _ = self.builder.buildBr(parent_block);
+
+            self.builder.positionBuilderAtEnd(else_block);
+            _ = self.builder.buildMemCpy(
+                dest_ptr,
+                dest_ptr_ty.ptrAlignment(mod),
+                src_ptr,
+                src_ptr_ty.ptrAlignment(mod),
+                len,
+                is_volatile,
+            );
+            _ = self.builder.buildBr(parent_block);
+            self.llvm_func.appendExistingBasicBlock(parent_block);
+            self.builder.positionBuilderAtEnd(parent_block);
+            return null;
+        }
+
         _ = self.builder.buildMemCpy(
             dest_ptr,
             dest_ptr_ty.ptrAlignment(mod),

From d54ebf4356eaeeab4d256d0da4f81678226f81a6 Mon Sep 17 00:00:00 2001
From: Luuk de Gram
Date: Tue, 4 Jul 2023 20:12:48 +0200
Subject: [PATCH 2/3] llvm: add safety-check for Wasm memset

When the bulk-memory feature is enabled, LLVM lowers the `memset`
instruction to WebAssembly's `memory.fill` instruction. That instruction
traps when the destination address is out-of-bounds, regardless of the
length. By Zig's semantics it is valid for a slice's pointer to be
invalid when its length is 0. To prevent runtimes from trapping, we add a
safety check for slices so that a memset is only emitted when the length
is greater than 0.
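
For illustration, this is the situation the check guards against (a
minimal sketch in the spirit of the standalone test added later in this
series; the helper name `empty` is invented for the example and is not
part of the change):

    const std = @import("std");

    // Build a zero-length slice whose pointer is intentionally dangling.
    // By Zig's semantics this is a valid value, because the length is 0.
    fn empty() []u8 {
        const addr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
        return @as([*]align(1) u8, @ptrFromInt(addr))[0..0];
    }

    test "zero-length memset and memcpy must not trap" {
        var dest = empty();
        // With bulk-memory these lower to memory.fill/memory.copy; without
        // the length check the Wasm runtime would trap on the dangling pointer.
        @memset(dest, 0);
        @memcpy(dest, empty());
    }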
---
 src/codegen/llvm.zig | 51 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 48 insertions(+), 3 deletions(-)

diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index 597dfc8e4f..cd679b9156 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -8511,6 +8511,14 @@ pub const FuncGen = struct {
         const dest_ptr = self.sliceOrArrayPtr(dest_slice, ptr_ty);
         const is_volatile = ptr_ty.isVolatilePtr(mod);
 
+        // Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
+        // of the length. This means we need to emit a check where we skip the memset when the length
+        // is 0 as we allow for undefined pointers in 0-sized slices.
+        const needs_wasm_safety_check = safety and
+            o.target.isWasm() and
+            ptr_ty.isSlice(mod) and
+            std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
+
         if (try self.air.value(bin_op.rhs, mod)) |elem_val| {
             if (elem_val.isUndefDeep(mod)) {
                 // Even if safety is disabled, we still emit a memset to undefined since it conveys
@@ -8521,7 +8529,11 @@
                 else
                     u8_llvm_ty.getUndef();
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                if (needs_wasm_safety_check) {
+                    try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                } else {
+                    _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                }
 
                 if (safety and mod.comp.bin_file.options.valgrind) {
                     self.valgrindMarkUndef(dest_ptr, len);
@@ -8539,7 +8551,12 @@
                     .val = byte_val,
                 });
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+
+                if (needs_wasm_safety_check) {
+                    try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                } else {
+                    _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                }
                 return null;
             }
         }
@@ -8551,7 +8568,12 @@
             // In this case we can take advantage of LLVM's intrinsic.
             const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
             const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-            _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+
+            if (needs_wasm_safety_check) {
+                try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+            } else {
+                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+            }
             return null;
         }
 
@@ -8622,6 +8644,29 @@ pub const FuncGen = struct {
         return null;
     }
 
+    fn safeWasmMemset(
+        self: *FuncGen,
+        dest_ptr: *llvm.Value,
+        fill_byte: *llvm.Value,
+        len: *llvm.Value,
+        dest_ptr_align: u32,
+        is_volatile: bool,
+    ) !void {
+        const parent_block = self.context.createBasicBlock("Block");
+        const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
+        const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
+        const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
+        const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
+        _ = self.builder.buildCondBr(cond, then_block, else_block);
+        self.builder.positionBuilderAtEnd(then_block);
+        _ = self.builder.buildBr(parent_block);
+        self.builder.positionBuilderAtEnd(else_block);
+        _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+        _ = self.builder.buildBr(parent_block);
+        self.llvm_func.appendExistingBasicBlock(parent_block);
+        self.builder.positionBuilderAtEnd(parent_block);
+    }
+
     fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
         const o = self.dg.object;
         const mod = o.module;

From 37e2a04da8688a168ed0ad81bf3ead5d9a3b8474 Mon Sep 17 00:00:00 2001
From: Luuk de Gram
Date: Thu, 6 Jul 2023 19:31:08 +0200
Subject: [PATCH 3/3] add standalone test to verify bulk-memory features

This adds a standalone test case to ensure the runtime does not trap
when performing a memory.copy or memory.fill instruction while the
destination or source address is out-of-bounds and the length is 0.
---
 src/codegen/llvm.zig                          | 56 ++++++++-----------
 test/standalone.zig                           |  4 ++
 test/standalone/zerolength_check/build.zig    | 27 +++++++++
 test/standalone/zerolength_check/src/main.zig | 23 ++++++++
 4 files changed, 77 insertions(+), 33 deletions(-)
 create mode 100644 test/standalone/zerolength_check/build.zig
 create mode 100644 test/standalone/zerolength_check/src/main.zig

diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index cd679b9156..ec456e53a7 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -8514,8 +8514,8 @@ pub const FuncGen = struct {
         // Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
         // of the length. This means we need to emit a check where we skip the memset when the length
         // is 0 as we allow for undefined pointers in 0-sized slices.
-        const needs_wasm_safety_check = safety and
-            o.target.isWasm() and
+        // This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
+        const intrinsic_len0_traps = o.target.isWasm() and
             ptr_ty.isSlice(mod) and
             std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
 
@@ -8529,7 +8529,7 @@
                 else
                     u8_llvm_ty.getUndef();
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                if (needs_wasm_safety_check) {
+                if (intrinsic_len0_traps) {
                     try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
                 } else {
                     _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8552,7 +8552,7 @@
                 });
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
 
-                if (needs_wasm_safety_check) {
+                if (intrinsic_len0_traps) {
                     try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
                 } else {
                     _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8569,7 +8569,7 @@
             const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
             const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
 
-            if (needs_wasm_safety_check) {
+            if (intrinsic_len0_traps) {
                 try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
             } else {
                 _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8652,19 +8652,15 @@
         dest_ptr_align: u32,
         is_volatile: bool,
     ) !void {
-        const parent_block = self.context.createBasicBlock("Block");
         const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
-        const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
-        const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
-        const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
-        _ = self.builder.buildCondBr(cond, then_block, else_block);
-        self.builder.positionBuilderAtEnd(then_block);
-        _ = self.builder.buildBr(parent_block);
-        self.builder.positionBuilderAtEnd(else_block);
+        const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
+        const memset_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapSkip");
+        const end_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapEnd");
+        _ = self.builder.buildCondBr(cond, memset_block, end_block);
+        self.builder.positionBuilderAtEnd(memset_block);
         _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
-        _ = self.builder.buildBr(parent_block);
-        self.llvm_func.appendExistingBasicBlock(parent_block);
-        self.builder.positionBuilderAtEnd(parent_block);
+        _ = self.builder.buildBr(end_block);
+        self.builder.positionBuilderAtEnd(end_block);
     }
 
     fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
@@ -8682,24 +8678,19 @@
 
         // When bulk-memory is enabled, this will be lowered to WebAssembly's memory.copy instruction.
        // This instruction will trap on an invalid address, regardless of the length.
-        // For this reason we must add a safety-check for 0-sized slices as its pointer field can be undefined.
+        // For this reason we must add a check for 0-sized slices as its pointer field can be undefined.
         // We only have to do this for slices as arrays will have a valid pointer.
+        // This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
         if (o.target.isWasm() and
             std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory) and
-            (src_ptr_ty.isSlice(mod) or dest_ptr_ty.isSlice(mod)))
+            dest_ptr_ty.isSlice(mod))
         {
-            const parent_block = self.context.createBasicBlock("Block");
-
-            const llvm_usize_ty = self.context.intType(o.target.ptrBitWidth());
-            const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
-            const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
-            const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
-            _ = self.builder.buildCondBr(cond, then_block, else_block);
-
-            self.builder.positionBuilderAtEnd(then_block);
-            _ = self.builder.buildBr(parent_block);
-
-            self.builder.positionBuilderAtEnd(else_block);
+            const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
+            const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
+            const memcpy_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapSkip");
+            const end_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapEnd");
+            _ = self.builder.buildCondBr(cond, memcpy_block, end_block);
+            self.builder.positionBuilderAtEnd(memcpy_block);
             _ = self.builder.buildMemCpy(
                 dest_ptr,
                 dest_ptr_ty.ptrAlignment(mod),
@@ -8708,9 +8699,8 @@
                 len,
                 is_volatile,
             );
-            _ = self.builder.buildBr(parent_block);
-            self.llvm_func.appendExistingBasicBlock(parent_block);
-            self.builder.positionBuilderAtEnd(parent_block);
+            _ = self.builder.buildBr(end_block);
+            self.builder.positionBuilderAtEnd(end_block);
             return null;
         }
 
diff --git a/test/standalone.zig b/test/standalone.zig
index b15e5f7033..d812669664 100644
--- a/test/standalone.zig
+++ b/test/standalone.zig
@@ -230,6 +230,10 @@ pub const build_cases = [_]BuildCase{
         .build_root = "test/standalone/cmakedefine",
         .import = @import("standalone/cmakedefine/build.zig"),
     },
+    .{
+        .build_root = "test/standalone/zerolength_check",
+        .import = @import("standalone/zerolength_check/build.zig"),
+    },
 };
 
 const std = @import("std");
diff --git a/test/standalone/zerolength_check/build.zig b/test/standalone/zerolength_check/build.zig
new file mode 100644
index 0000000000..4e5c1937f7
--- /dev/null
+++ b/test/standalone/zerolength_check/build.zig
@@ -0,0 +1,27 @@
+const std = @import("std");
+
+pub fn build(b: *std.Build) void {
+    const test_step = b.step("test", "Test it");
+    b.default_step = test_step;
+
+    add(b, test_step, .Debug);
+    add(b, test_step, .ReleaseFast);
+    add(b, test_step, .ReleaseSmall);
+    add(b, test_step, .ReleaseSafe);
+}
+
+fn add(b: *std.Build, test_step: *std.Build.Step, optimize: std.builtin.OptimizeMode) void {
+    const unit_tests = b.addTest(.{
+        .root_source_file = .{ .path = "src/main.zig" },
+        .target = .{
+            .os_tag = .wasi,
+            .cpu_arch = .wasm32,
+            .cpu_features_add = std.Target.wasm.featureSet(&.{.bulk_memory}),
+        },
+        .optimize = optimize,
+    });
+
+    const run_unit_tests = b.addRunArtifact(unit_tests);
+    run_unit_tests.skip_foreign_checks = true;
+    test_step.dependOn(&run_unit_tests.step);
+}
diff --git a/test/standalone/zerolength_check/src/main.zig b/test/standalone/zerolength_check/src/main.zig
new file mode 100644
index 0000000000..e1e2c5d376
--- /dev/null
+++ b/test/standalone/zerolength_check/src/main.zig
@@ -0,0 +1,23 @@
+const std = @import("std");
+
+test {
+    var dest = foo();
+    var source = foo();
+
+    @memcpy(dest, source);
+    @memset(dest, 4);
+    @memset(dest, undefined);
+
+    var dest2 = foo2();
+    @memset(dest2, 0);
+}
+
+fn foo() []u8 {
+    const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
+    return @as([*]align(1) u8, @ptrFromInt(ptr))[0..0];
+}
+
+fn foo2() []u64 {
+    const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
+    return @as([*]align(1) u64, @ptrFromInt(ptr))[0..0];
+}
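
For both paths, the emitted control flow is roughly what you would get by
wrapping the intrinsic in a length check at the source level. The snippet
below is only an illustrative sketch of that shape (the names `dest` and
`byte` are placeholders), not code from this series:

    // Skip memory.fill / memory.copy entirely for zero-length slices, so a
    // dangling pointer in an empty slice never reaches the Wasm runtime.
    if (dest.len != 0) {
        @memset(dest, byte);
    }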