wasm: Implement memset and sret arguments.

We now detect when the return value must be returned by reference. In that case
the caller passes a pointer to stack memory in its own frame as the first
argument, and the callee writes the result through it. This way, we do not have
to worry about the callee's stack memory being overwritten before the caller
has read the return value.

Besides this, we implement memset either by using wasm's memory.fill instruction
when the bulk_memory feature is available, or by lowering it manually.
In the future we can lower this to a compiler_rt call instead.
Luuk de Gram 2022-01-04 17:00:21 +01:00
parent 5c21a45cf0
commit 89b1fdc443
5 changed files with 185 additions and 56 deletions
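
For illustration, here is a small Zig program exercising both new code paths. This is a hypothetical example, not part of the commit; std.mem.set stands in for whatever the frontend lowers to a memset:

const std = @import("std");

const Slice = struct { ptr: usize, len: usize };

// With this change, a by-ref return type is lowered as
//   (func $makeSlice (param i32))
// where the i32 parameter is a pointer into the caller's stack frame.
fn makeSlice() Slice {
    return .{ .ptr = 0, .len = 8 };
}

test "memset lowering" {
    var buf: [32]u8 = undefined;
    // Lowered to a single memory.fill when bulk_memory is enabled,
    // otherwise to the manual store loop introduced in this commit.
    std.mem.set(u8, buf[0..], 0xaa);
    try std.testing.expectEqual(@as(u8, 0xaa), buf[0]);
    _ = makeSlice();
}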


@@ -212,6 +212,28 @@ test "Wasm - opcodes" {
try testing.expectEqual(@as(u16, 0xC4), i64_extend32_s);
}
/// Opcodes that require a prefix `0xFC`
pub const PrefixedOpcode = enum(u8) {
i32_trunc_sat_f32_s = 0x00,
i32_trunc_sat_f32_u = 0x01,
i32_trunc_sat_f64_s = 0x02,
i32_trunc_sat_f64_u = 0x03,
i64_trunc_sat_f32_s = 0x04,
i64_trunc_sat_f32_u = 0x05,
i64_trunc_sat_f64_s = 0x06,
i64_trunc_sat_f64_u = 0x07,
memory_init = 0x08,
data_drop = 0x09,
memory_copy = 0x0A,
memory_fill = 0x0B,
table_init = 0x0C,
elem_drop = 0x0D,
table_copy = 0x0E,
table_grow = 0x0F,
table_size = 0x10,
table_fill = 0x11,
};
/// Enum representing all Wasm value types as per spec:
/// https://webassembly.github.io/spec/core/binary/types.html
pub const Valtype = enum(u8) {
@@ -266,7 +288,7 @@ pub const InitExpression = union(enum) {
global_get: u32,
};
///
/// Represents a function entry, holding the index to its type
pub const Func = struct {
type_index: u32,
};
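In the style of the opcode test above, a quick sanity check for the new enum could look like this (a sketch, not part of the commit):

test "Wasm - prefixed opcodes" {
    const testing = @import("std").testing;
    try testing.expectEqual(@as(u8, 0x0B), @enumToInt(PrefixedOpcode.memory_fill));
    try testing.expectEqual(@as(u8, 0x0A), @enumToInt(PrefixedOpcode.memory_copy));
}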


@@ -623,6 +623,10 @@ fn addTag(self: *Self, tag: Mir.Inst.Tag) error{OutOfMemory}!void {
try self.addInst(.{ .tag = tag, .data = .{ .tag = {} } });
}
fn addExtended(self: *Self, opcode: wasm.PrefixedOpcode) error{OutOfMemory}!void {
try self.addInst(.{ .tag = .extended, .secondary = @enumToInt(opcode), .data = .{ .tag = {} } });
}
fn addLabel(self: *Self, tag: Mir.Inst.Tag, label: u32) error{OutOfMemory}!void {
try self.addInst(.{ .tag = tag, .data = .{ .label = label } });
}
@@ -746,6 +750,13 @@ fn genFunctype(self: *Self, fn_ty: Type) !wasm.Type {
defer params.deinit();
var returns = std.ArrayList(wasm.Valtype).init(self.gpa);
defer returns.deinit();
const return_type = fn_ty.fnReturnType();
const want_sret = isByRef(return_type);
if (want_sret) {
try params.append(try self.typeToValtype(Type.usize));
}
// param types
if (fn_ty.fnParamLen() != 0) {
@@ -759,11 +770,8 @@ fn genFunctype(self: *Self, fn_ty: Type) !wasm.Type {
}
// return type
const return_type = fn_ty.fnReturnType();
switch (return_type.zigTypeTag()) {
.Void, .NoReturn => {},
.Struct => return self.fail("TODO: Implement struct as return type for wasm", .{}),
else => try returns.append(try self.typeToValtype(return_type)),
if (!want_sret and return_type.hasCodeGenBits()) {
try returns.append(try self.typeToValtype(return_type));
}
return wasm.Type{
@@ -785,6 +793,15 @@ pub fn genFunc(self: *Self) InnerError!Result {
// Generate MIR for function body
try self.genBody(self.air.getMainBody());
// In case we have a return value, but the last instruction is noreturn (such as a while loop),
// we emit an unreachable instruction to tell the stack validator that this part will never be reached.
if (func_type.returns.len != 0 and self.air.instructions.len > 0) {
const inst = @intCast(u32, self.air.instructions.len - 1);
if (self.air.typeOfIndex(inst).isNoReturn()) {
try self.addTag(.@"unreachable");
}
}
// End of function body
try self.addTag(.end);
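A sketch of the case this guards against: the function below has wasm result type i32, yet its body ends in a loop that never returns, so without a trailing unreachable the emitted code would fail stack validation (hypothetical example):

fn spin(cond: bool) u32 {
    if (cond) return 1;
    // This loop is noreturn, so no i32 result is ever pushed; the backend
    // now appends `unreachable` after the body so the (result i32)
    // function type still validates.
    while (true) {}
}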
@@ -1074,6 +1091,15 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
.return_value = .none,
};
errdefer self.gpa.free(result.args);
const ret_ty = fn_ty.fnReturnType();
// Check if we store the result as a pointer to the stack rather than
// by value
if (isByRef(ret_ty)) {
// the sret arg will be passed as first argument, therefore we
// set the `return_value` before allocating locals for regular args.
result.return_value = .{ .local = self.local_index };
self.local_index += 1;
}
switch (cc) {
.Naked => return result,
.Unspecified, .C => {
@@ -1086,19 +1112,6 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
result.args[ty_index] = .{ .local = self.local_index };
self.local_index += 1;
}
const ret_ty = fn_ty.fnReturnType();
// Check if we store the result as a pointer to the stack rather than
// by value
if (isByRef(ret_ty)) {
if (self.initial_stack_value == .none) try self.initializeStack();
result.return_value = try self.allocStack(ret_ty);
// We want to make sure the return value's stack value doesn't get overwritten,
// so set initial stack value to current's position instead.
try self.addLabel(.global_get, 0);
try self.addLabel(.local_set, self.initial_stack_value.local);
}
},
else => return self.fail("TODO implement function parameters for cc '{}' on wasm", .{cc}),
}
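The ordering here matters: the sret pointer must claim local index 0 before any regular parameter locals are assigned. A minimal illustrative check of that bookkeeping (hypothetical, mirroring the logic above):

test "sret claims local index 0" {
    var local_index: u32 = 0;
    const sret_local = local_index; // return_value is set before the args
    local_index += 1; // regular parameter locals start at 1
    try @import("std").testing.expectEqual(@as(u32, 0), sret_local);
}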
@@ -1323,6 +1336,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.load => self.airLoad(inst),
.loop => self.airLoop(inst),
.memset => self.airMemset(inst),
.not => self.airNot(inst),
.optional_payload => self.airOptionalPayload(inst),
.optional_payload_ptr => self.airOptionalPayloadPtr(inst),
@@ -1335,18 +1349,21 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.ret => self.airRet(inst),
.ret_ptr => self.airRetPtr(inst),
.ret_load => self.airRetLoad(inst),
.slice => self.airSlice(inst),
.slice_len => self.airSliceLen(inst),
.slice_elem_val => self.airSliceElemVal(inst),
.slice_elem_ptr => self.airSliceElemPtr(inst),
.slice_ptr => self.airSlicePtr(inst),
.store => self.airStore(inst),
.struct_field_ptr => self.airStructFieldPtr(inst),
.struct_field_ptr_index_0 => self.airStructFieldPtrIndex(inst, 0),
.struct_field_ptr_index_1 => self.airStructFieldPtrIndex(inst, 1),
.struct_field_ptr_index_2 => self.airStructFieldPtrIndex(inst, 2),
.struct_field_ptr_index_3 => self.airStructFieldPtrIndex(inst, 3),
.struct_field_val => self.airStructFieldVal(inst),
.switch_br => self.airSwitchBr(inst),
.trunc => self.airTrunc(inst),
.unreach => self.airUnreachable(inst),
@@ -1374,7 +1391,6 @@ fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
// to the stack instead
if (self.return_value != .none) {
try self.store(self.return_value, operand, self.decl.ty.fnReturnType(), 0);
try self.emitWValue(self.return_value);
} else {
try self.emitWValue(operand);
}
@@ -1393,6 +1409,9 @@ fn airRetPtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
if (child_type.abiSize(self.target) == 0) return WValue{ .none = {} };
if (isByRef(child_type)) {
return self.return_value;
}
return self.allocStack(child_type);
}
@@ -1402,9 +1421,7 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ret_ty = self.air.typeOf(un_op).childType();
if (!ret_ty.hasCodeGenBits()) return WValue.none;
if (isByRef(ret_ty)) {
try self.emitWValue(operand);
} else {
if (!isByRef(ret_ty)) {
const result = try self.load(operand, ret_ty, 0);
try self.emitWValue(result);
}
@@ -1425,6 +1442,8 @@ fn airCall(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
.Pointer => ty.childType(),
else => unreachable,
};
const ret_ty = fn_ty.fnReturnType();
const first_param_sret = isByRef(ret_ty);
const target: ?*Decl = blk: {
const func_val = self.air.value(pl_op.operand) orelse break :blk null;
@@ -1437,6 +1456,12 @@ fn airCall(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
return self.fail("Expected a function, but instead found type '{s}'", .{func_val.tag()});
};
const sret = if (first_param_sret) blk: {
const sret_local = try self.allocStack(ret_ty);
try self.emitWValue(sret_local);
break :blk sret_local;
} else WValue{ .none = {} };
for (args) |arg| {
const arg_ref = @intToEnum(Air.Inst.Ref, arg);
const arg_val = self.resolveInst(arg_ref);
@@ -1475,32 +1500,19 @@ fn airCall(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
try self.addLabel(.call_indirect, fn_type_index);
}
const ret_ty = fn_ty.fnReturnType();
if (!ret_ty.hasCodeGenBits()) return WValue.none;
// TODO: Implement this for all aggregate types
if (ret_ty.isSlice()) {
// first load the values onto the regular stack, before we move the stack pointer
// to prevent overwriting the return value.
const tmp = try self.allocLocal(ret_ty);
try self.addLabel(.local_set, tmp.local);
const field_ty = Type.@"usize";
const offset = @intCast(u32, field_ty.abiSize(self.target));
const ptr_local = try self.load(tmp, field_ty, 0);
const len_local = try self.load(tmp, field_ty, offset);
// As our values are now safe, we reserve space on the virtual stack and
// store the values there.
const result = try self.allocStack(ret_ty);
try self.store(result, ptr_local, field_ty, 0);
try self.store(result, len_local, field_ty, offset);
return result;
}
if (self.liveness.isUnused(inst) or !ret_ty.hasCodeGenBits()) {
return WValue.none;
} else if (ret_ty.isNoReturn()) {
try self.addTag(.@"unreachable");
return WValue.none;
} else if (first_param_sret) {
return sret;
} else {
const result_local = try self.allocLocal(ret_ty);
try self.addLabel(.local_set, result_local.local);
return result_local;
}
}
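In source terms, the call-site protocol above corresponds roughly to the following sketch; callee and S are hypothetical names:

const S = struct { a: u64, b: u64, c: u64 };

extern fn callee(sret: *S) void; // lowered signature: sret pointer first

fn caller() S {
    var ret: S = undefined; // allocStack: a slot in the caller's own frame
    callee(&ret);           // the address is pushed before the regular args
    return ret;             // the call's result is the sret WValue itself
}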
fn airAlloc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const pointee_type = self.air.typeOfIndex(inst).childType();
@@ -1989,7 +2001,7 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!void {
// validator will not accept it due to out-of-bounds memory access);
.Array => try self.addImm32(@bitCast(i32, @as(u32, 0xaa))),
.Struct => {
// TODO: Write 0xaa to each field
// TODO: Write 0xaa to the struct's memory
const result = try self.allocStack(ty);
try self.addLabel(.local_get, result.local);
},
@@ -2943,3 +2955,71 @@ fn airPtrBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
try self.addLabel(.local_set, result.local);
return result;
}
fn airMemset(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const pl_op = self.air.instructions.items(.data)[inst].pl_op;
const bin_op = self.air.extraData(Air.Bin, pl_op.payload).data;
const ptr = self.resolveInst(pl_op.operand);
const value = self.resolveInst(bin_op.lhs);
const len = self.resolveInst(bin_op.rhs);
try self.memSet(ptr, len, value);
return WValue.none;
}
/// Sets a region of memory at `ptr` to the value of `value`.
/// When the user has enabled the bulk_memory feature, we lower
/// this to wasm's memory.fill instruction. When the feature is not present,
/// we implement it manually.
fn memSet(self: *Self, ptr: WValue, len: WValue, value: WValue) InnerError!void {
// When bulk_memory is enabled, we lower it to wasm's memory.fill instruction.
// If not, we lower it ourselves.
if (std.Target.wasm.featureSetHas(self.target.cpu.features, .bulk_memory)) {
try self.emitWValue(ptr);
try self.emitWValue(value);
try self.emitWValue(len);
try self.addExtended(.memory_fill);
return;
}
// TODO: We should probably lower this to a call to compiler_rt
// But for now, we implement it manually
const offset = try self.allocLocal(Type.usize); // local for counter
// outer block to jump to when loop is done
try self.startBlock(.block, wasm.block_empty);
try self.startBlock(.loop, wasm.block_empty);
try self.emitWValue(offset);
try self.emitWValue(len);
switch (self.ptrSize()) {
4 => try self.addTag(.i32_eq),
8 => try self.addTag(.i64_eq),
else => unreachable,
}
try self.addLabel(.br_if, 1); // jump out of loop into outer block (finished)
try self.emitWValue(ptr);
try self.emitWValue(offset);
switch (self.ptrSize()) {
4 => try self.addTag(.i32_add),
8 => try self.addTag(.i64_add),
else => unreachable,
}
try self.emitWValue(value);
const mem_store_op: Mir.Inst.Tag = switch (self.ptrSize()) {
4 => .i32_store8,
8 => .i64_store8,
else => unreachable,
};
try self.addMemArg(mem_store_op, .{ .offset = 0, .alignment = 1 });
try self.emitWValue(offset);
try self.addImm32(1);
switch (self.ptrSize()) {
4 => try self.addTag(.i32_add),
8 => try self.addTag(.i64_add),
else => unreachable,
}
try self.addLabel(.local_set, offset.local);
try self.addLabel(.br, 0); // jump to start of loop
try self.endBlock();
try self.endBlock();
}
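The fallback loop emitted above behaves like the following Zig; this is a sketch, since the real lowering operates on raw wasm locals and byte stores:

fn memsetFallback(ptr: [*]u8, len: usize, value: u8) void {
    var offset: usize = 0;                  // the counter local
    while (offset != len) : (offset += 1) { // eq + br_if out, br back
        ptr[offset] = value;                // i32_store8 / i64_store8
    }
}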


@@ -161,6 +161,8 @@ pub fn emitMir(emit: *Emit) InnerError!void {
.i64_extend8_s => try emit.emitTag(tag),
.i64_extend16_s => try emit.emitTag(tag),
.i64_extend32_s => try emit.emitTag(tag),
.extended => try emit.emitExtended(inst),
}
}
}
@@ -321,3 +323,20 @@ fn emitMemAddress(emit: *Emit, inst: Mir.Inst.Index) !void {
.relocation_type = .R_WASM_MEMORY_ADDR_LEB,
});
}
fn emitExtended(emit: *Emit, inst: Mir.Inst.Index) !void {
const opcode = emit.mir.instructions.items(.secondary)[inst];
switch (@intToEnum(std.wasm.PrefixedOpcode, opcode)) {
.memory_fill => try emit.emitMemFill(),
else => |tag| return emit.fail("TODO: Implement extension instruction: {s}\n", .{@tagName(tag)}),
}
}
fn emitMemFill(emit: *Emit) !void {
try emit.code.append(0xFC);
try emit.code.append(0x0B);
// When multi-memory proposal reaches phase 4, we
// can emit a different memory index here.
// For now we will always emit index 0.
try leb128.writeULEB128(emit.code.writer(), @as(u32, 0));
}
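Per the bulk-memory proposal, the full encoding for memory index 0 is therefore three bytes (shown as a sketch):

// 0xFC: extension prefix
// 0x0B: memory.fill sub-opcode
// 0x00: memory index as ULEB128 (always 0 until multi-memory is accepted)
const memory_fill_bytes = [_]u8{ 0xFC, 0x0B, 0x00 };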


@@ -19,6 +19,9 @@ extra: []const u32,
pub const Inst = struct {
/// The opcode that represents this instruction
tag: Tag,
/// This opcode will be set when `tag` represents an extended
/// instruction with prefix 0xFC, or a simd instruction with prefix 0xFD.
secondary: u8 = 0,
/// Data is determined by the set `tag`.
/// For example, `data` will be an i32 for when `tag` is 'i32_const'.
data: Data,
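As a sketch of how the new field is used (mirroring addExtended in the codegen change above), an extended instruction is built like this:

const inst = Inst{
    .tag = .extended,
    .secondary = @enumToInt(std.wasm.PrefixedOpcode.memory_fill), // 0x0B
    .data = .{ .tag = {} },
};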
@@ -373,6 +376,11 @@ pub const Inst = struct {
i64_extend16_s = 0xC3,
/// Uses `tag`
i64_extend32_s = 0xC4,
/// The instruction consists of an extension opcode
/// set in `secondary`
///
/// The `data` field depends on the extension instruction
extended = 0xFC,
/// Contains a symbol to a function pointer
/// uses `label`
///


@@ -39,16 +39,23 @@ test {
_ = @import("behavior/defer.zig");
_ = @import("behavior/enum.zig");
_ = @import("behavior/error.zig");
_ = @import("behavior/generics.zig");
_ = @import("behavior/if.zig");
_ = @import("behavior/import.zig");
_ = @import("behavior/incomplete_struct_param_tld.zig");
_ = @import("behavior/inttoptr.zig");
_ = @import("behavior/member_func.zig");
_ = @import("behavior/null.zig");
_ = @import("behavior/pointers.zig");
_ = @import("behavior/ptrcast.zig");
_ = @import("behavior/ref_var_in_if_after_if_2nd_switch_prong.zig");
_ = @import("behavior/struct.zig");
_ = @import("behavior/this.zig");
_ = @import("behavior/truncate.zig");
_ = @import("behavior/usingnamespace.zig");
_ = @import("behavior/underscore.zig");
_ = @import("behavior/usingnamespace.zig");
_ = @import("behavior/void.zig");
_ = @import("behavior/while.zig");
if (!builtin.zig_is_stage2 or builtin.stage2_arch != .wasm32) {
// Tests that pass for stage1, llvm backend, C backend
@@ -56,16 +63,9 @@ test {
_ = @import("behavior/array.zig");
_ = @import("behavior/cast.zig");
_ = @import("behavior/for.zig");
_ = @import("behavior/generics.zig");
_ = @import("behavior/int128.zig");
_ = @import("behavior/member_func.zig");
_ = @import("behavior/null.zig");
_ = @import("behavior/optional.zig");
_ = @import("behavior/struct.zig");
_ = @import("behavior/this.zig");
_ = @import("behavior/translate_c_macros.zig");
_ = @import("behavior/while.zig");
_ = @import("behavior/void.zig");
if (builtin.object_format != .c) {
// Tests that pass for stage1 and the llvm backend.