wasm: Implement codegen for C-ABI

This implements passing arguments and storing return values correctly
for the C ABI, as specified by the WebAssembly tool-conventions document:
https://github.com/WebAssembly/tool-conventions/blob/main/BasicCABI.md

There's definitely room for better codegen in follow-up commits.
This commit is contained in:
Luuk de Gram 2022-04-24 21:49:12 +02:00
parent cb49af6c9a
commit 5f2d0d414d
2 changed files with 209 additions and 48 deletions

View File

@ -21,6 +21,7 @@ const Air = @import("../../Air.zig");
const Liveness = @import("../../Liveness.zig");
const Mir = @import("Mir.zig");
const Emit = @import("Emit.zig");
const abi = @import("abi.zig");
/// Wasm Value, created when generating an instruction
const WValue = union(enum) {
@ -722,18 +723,15 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
const bits = ty.floatBits(target);
if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
if (bits == 64) break :blk wasm.Valtype.f64;
if (bits == 128) break :blk wasm.Valtype.i64;
return wasm.Valtype.i32; // represented as pointer to stack
},
.Int => blk: {
.Int, .Enum => blk: {
const info = ty.intInfo(target);
if (info.bits <= 32) break :blk wasm.Valtype.i32;
if (info.bits > 32 and info.bits <= 64) break :blk wasm.Valtype.i64;
if (info.bits > 32 and info.bits <= 128) break :blk wasm.Valtype.i64;
break :blk wasm.Valtype.i32; // represented as pointer to stack
},
.Enum => {
var buf: Type.Payload.Bits = undefined;
return typeToValtype(ty.intTagType(&buf), target);
},
else => wasm.Valtype.i32, // all represented as reference/immediate
};
}
@ -787,33 +785,46 @@ fn allocLocal(self: *Self, ty: Type) InnerError!WValue {
/// Generates a `wasm.Type` from a given function type.
/// Memory is owned by the caller.
fn genFunctype(gpa: Allocator, fn_ty: Type, target: std.Target) !wasm.Type {
fn genFunctype(gpa: Allocator, fn_info: Type.Payload.Function.Data, target: std.Target) !wasm.Type {
var params = std.ArrayList(wasm.Valtype).init(gpa);
defer params.deinit();
var returns = std.ArrayList(wasm.Valtype).init(gpa);
defer returns.deinit();
const return_type = fn_ty.fnReturnType();
const want_sret = isByRef(return_type, target);
if (want_sret) {
try params.append(typeToValtype(return_type, target));
}
// param types
if (fn_ty.fnParamLen() != 0) {
const fn_params = try gpa.alloc(Type, fn_ty.fnParamLen());
defer gpa.free(fn_params);
fn_ty.fnParamTypes(fn_params);
for (fn_params) |param_type| {
if (!param_type.hasRuntimeBitsIgnoreComptime()) continue;
try params.append(typeToValtype(param_type, target));
if (firstParamSRet(fn_info, target)) {
try params.append(typeToValtype(fn_info.return_type, target));
} else if (fn_info.return_type.hasRuntimeBitsIgnoreComptime()) {
if (fn_info.cc == .C) {
const res_classes = abi.classifyType(fn_info.return_type, target);
assert(res_classes[0] == .direct and res_classes[1] == .none);
const scalar_type = abi.scalarType(fn_info.return_type, target);
try returns.append(typeToValtype(scalar_type, target));
} else {
try returns.append(typeToValtype(fn_info.return_type, target));
}
}
// return type
if (!want_sret and return_type.hasRuntimeBitsIgnoreComptime()) {
try returns.append(typeToValtype(return_type, target));
// param types
if (fn_info.param_types.len != 0) {
for (fn_info.param_types) |param_type| {
if (!param_type.hasRuntimeBitsIgnoreComptime()) continue;
switch (fn_info.cc) {
.C => {
const param_classes = abi.classifyType(param_type, target);
for (param_classes) |class| {
if (class == .none) continue;
if (class == .direct) {
const scalar_type = abi.scalarType(param_type, target);
try params.append(typeToValtype(scalar_type, target));
} else {
try params.append(typeToValtype(param_type, target));
}
}
},
else => try params.append(typeToValtype(param_type, target)),
}
}
}
return wasm.Type{
@ -857,7 +868,7 @@ pub fn generate(
}
fn genFunc(self: *Self) InnerError!void {
var func_type = try genFunctype(self.gpa, self.decl.ty, self.target);
var func_type = try genFunctype(self.gpa, self.decl.ty.fnInfo(), self.target);
defer func_type.deinit(self.gpa);
self.decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
@ -957,21 +968,22 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
.args = &.{},
.return_value = .none,
};
if (cc == .Naked) return result;
var args = std.ArrayList(WValue).init(self.gpa);
defer args.deinit();
const ret_ty = fn_ty.fnReturnType();
// Check if we store the result as a pointer to the stack rather than
// by value
if (isByRef(ret_ty, self.target)) {
if (firstParamSRet(fn_ty.fnInfo(), self.target)) {
// the sret arg will be passed as first argument, therefore we
// set the `return_value` before allocating locals for regular args.
result.return_value = .{ .local = self.local_index };
self.local_index += 1;
}
switch (cc) {
.Naked => return result,
.Unspecified, .C => {
.Unspecified => {
for (param_types) |ty| {
if (!ty.hasRuntimeBitsIgnoreComptime()) {
continue;
@ -981,12 +993,105 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
self.local_index += 1;
}
},
else => return self.fail("TODO implement function parameters for cc '{}' on wasm", .{cc}),
.C => {
for (param_types) |ty| {
const ty_classes = abi.classifyType(ty, self.target);
for (ty_classes) |class| {
if (class == .none) continue;
try args.append(.{ .local = self.local_index });
self.local_index += 1;
}
}
},
else => return self.fail("calling convention '{s}' not supported for Wasm", .{@tagName(cc)}),
}
result.args = args.toOwnedSlice();
return result;
}
/// Reports whether a function described by `fn_info` receives its return
/// value through an extra leading pointer parameter ("sret").
///
/// * `.Unspecified` / `.Inline`: decided by `isByRef` on the return type.
/// * `.C`: decided by the ABI classification of the return type; an
///   indirect first class, or two direct classes, both require sret.
/// * Any other calling convention: never.
fn firstParamSRet(fn_info: Type.Payload.Function.Data, target: std.Target) bool {
    return switch (fn_info.cc) {
        .Unspecified, .Inline => isByRef(fn_info.return_type, target),
        .C => blk: {
            const classes = abi.classifyType(fn_info.return_type, target);
            const passed_indirectly = classes[0] == .indirect;
            const needs_two_scalars = classes[0] == .direct and classes[1] == .direct;
            break :blk passed_indirectly or needs_two_scalars;
        },
        else => false,
    };
}
/// Lowers the argument `value` of type `ty` onto the wasm value stack in the
/// shape the calling convention `cc` expects.
///
/// For every convention other than `.C` the value is handed to `lowerToStack`
/// unchanged. For `.C` the type is classified with `abi.classifyType` and:
/// * structs/unions classified as indirect are handed to `lowerToStack`,
/// * structs/unions classified as direct are loaded from memory as the single
///   scalar returned by `abi.scalarType`,
/// * integers/floats with two direct classes (asserted to be 16 bytes) are
///   loaded from memory as two consecutive `i64` values,
/// * all remaining types are handed to `lowerToStack`.
///
/// Asserts that the first ABI class of `ty` is not `.none`.
fn lowerArg(self: *Self, cc: std.builtin.CallingConvention, ty: Type, value: WValue) !void {
    if (cc != .C) {
        return self.lowerToStack(value);
    }

    const ty_classes = abi.classifyType(ty, self.target);
    assert(ty_classes[0] != .none);
    switch (ty.zigTypeTag()) {
        .Struct, .Union => {
            if (ty_classes[0] == .indirect) {
                return self.lowerToStack(value);
            }
            assert(ty_classes[0] == .direct);
            // The aggregate wraps exactly one scalar; load that scalar by value.
            const scalar_type = abi.scalarType(ty, self.target);
            const abi_size = scalar_type.abiSize(self.target);
            // NOTE(review): `width` is given the ABI size in bytes here; confirm
            // that this is the unit `buildOpcode` expects for `.load`.
            const opcode = buildOpcode(.{
                .op = .load,
                .width = @intCast(u8, abi_size),
                .signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
                .valtype1 = typeToValtype(scalar_type, self.target),
            });
            try self.emitWValue(value);
            try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
                .offset = value.offset(),
                .alignment = scalar_type.abiAlignment(self.target),
            });
        },
        .Int, .Float => {
            if (ty_classes[1] == .none) {
                return self.lowerToStack(value);
            }
            assert(ty_classes[0] == .direct and ty_classes[1] == .direct);
            assert(ty.abiSize(self.target) == 16);
            // In this case we have an integer or float that must be lowered as 2 i64's.
            // Each half is an 8-byte access: the alignment of an `i64.load` may not
            // exceed its natural alignment of 8, and the upper half at `+ 8` is only
            // guaranteed to be 8-byte aligned even when the value itself is 16-byte
            // aligned.
            const half_alignment = 8;
            try self.emitWValue(value);
            try self.addMemArg(.i64_load, .{ .offset = value.offset(), .alignment = half_alignment });
            try self.emitWValue(value);
            try self.addMemArg(.i64_load, .{ .offset = value.offset() + 8, .alignment = half_alignment });
        },
        else => return self.lowerToStack(value),
    }
}
/// Pushes `value` onto the wasm value stack without storing it in a temporary.
/// When `value` is a `.stack_offset` with a non-zero offset, the offset is
/// added to the emitted value so that the resulting pointer is what remains
/// on the stack. All other variants are emitted as-is.
fn lowerToStack(self: *Self, value: WValue) !void {
    // Every variant begins by emitting the value itself.
    try self.emitWValue(value);

    const stack_off = switch (value) {
        .stack_offset => |off| off,
        else => return,
    };
    // A zero offset needs no pointer arithmetic.
    if (!(stack_off > 0)) return;

    // Add the offset using the pointer width of the target architecture.
    switch (self.arch()) {
        .wasm32 => {
            try self.addImm32(@bitCast(i32, stack_off));
            try self.addTag(.i32_add);
        },
        .wasm64 => {
            try self.addImm64(stack_off);
            try self.addTag(.i64_add);
        },
        else => unreachable,
    }
}
/// Creates a local for the initial stack value
/// Asserts `initial_stack_value` is `.none`
fn initializeStack(self: *Self) !void {
@ -1489,11 +1594,31 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const un_op = self.air.instructions.items(.data)[inst].un_op;
const operand = try self.resolveInst(un_op);
const ret_ty = self.decl.ty.fnReturnType();
// result must be stored in the stack and we return a pointer
// to the stack instead
if (self.return_value != .none) {
try self.store(self.return_value, operand, self.decl.ty.fnReturnType(), 0);
} else if (self.decl.ty.fnInfo().cc == .C and ret_ty.hasRuntimeBitsIgnoreComptime()) {
switch (ret_ty.zigTypeTag()) {
// Aggregate types can be lowered as a singular value
.Struct, .Union => {
const scalar_type = abi.scalarType(ret_ty, self.target);
try self.emitWValue(operand);
const opcode = buildOpcode(.{
.op = .load,
.width = @intCast(u8, scalar_type.abiSize(self.target)),
.signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
.valtype1 = typeToValtype(scalar_type, self.target),
});
try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
.offset = operand.offset(),
.alignment = scalar_type.abiAlignment(self.target),
});
},
else => try self.emitWValue(operand),
}
} else {
try self.emitWValue(operand);
}
@ -1509,9 +1634,10 @@ fn airRetPtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
return self.allocStack(Type.usize); // create pointer to void
}
if (isByRef(child_type, self.target)) {
if (firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
return self.return_value;
}
return self.allocStackPtr(inst);
}
@ -1521,7 +1647,7 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ret_ty = self.air.typeOf(un_op).childType();
if (!ret_ty.hasRuntimeBitsIgnoreComptime()) return WValue.none;
if (!isByRef(ret_ty, self.target)) {
if (!firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
const result = try self.load(operand, ret_ty, 0);
try self.emitWValue(result);
}
@ -1544,7 +1670,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
else => unreachable,
};
const ret_ty = fn_ty.fnReturnType();
const first_param_sret = isByRef(ret_ty, self.target);
const first_param_sret = firstParamSRet(fn_ty.fnInfo(), self.target);
const callee: ?*Decl = blk: {
const func_val = self.air.value(pl_op.operand) orelse break :blk null;
@ -1554,7 +1680,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
break :blk module.declPtr(func.data.owner_decl);
} else if (func_val.castTag(.extern_fn)) |extern_fn| {
const ext_decl = module.declPtr(extern_fn.data.owner_decl);
var func_type = try genFunctype(self.gpa, ext_decl.ty, self.target);
var func_type = try genFunctype(self.gpa, ext_decl.ty.fnInfo(), self.target);
defer func_type.deinit(self.gpa);
ext_decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
try self.bin_file.addOrUpdateImport(ext_decl);
@ -1579,10 +1705,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
const arg_ty = self.air.typeOf(arg_ref);
if (!arg_ty.hasRuntimeBitsIgnoreComptime()) continue;
switch (arg_val) {
.stack_offset => try self.emitWValue(try self.buildPointerOffset(arg_val, 0, .new)),
else => try self.emitWValue(arg_val),
}
try self.lowerArg(fn_ty.fnInfo().cc, arg_ty, arg_val);
}
if (callee) |direct| {
@ -1594,7 +1717,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
const operand = try self.resolveInst(pl_op.operand);
try self.emitWValue(operand);
var fn_type = try genFunctype(self.gpa, fn_ty, self.target);
var fn_type = try genFunctype(self.gpa, fn_ty.fnInfo(), self.target);
defer fn_type.deinit(self.gpa);
const fn_type_index = try self.bin_file.putOrGetFuncType(fn_type);
@ -1608,6 +1731,14 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
return WValue.none;
} else if (first_param_sret) {
return sret;
// TODO: Make this less fragile and optimize
} else if (fn_ty.fnInfo().cc == .C and ret_ty.zigTypeTag() == .Struct or ret_ty.zigTypeTag() == .Union) {
const result_local = try self.allocLocal(ret_ty);
try self.addLabel(.local_set, result_local.local);
const scalar_type = abi.scalarType(ret_ty, self.target);
const result = try self.allocStack(scalar_type);
try self.store(result, result_local, scalar_type, 0);
return result;
} else {
const result_local = try self.allocLocal(ret_ty);
try self.addLabel(.local_set, result_local.local);
@ -1749,9 +1880,20 @@ fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
}
fn airArg(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
_ = inst;
defer self.arg_index += 1;
return self.args[self.arg_index];
const arg = self.args[self.arg_index];
const cc = self.decl.ty.fnInfo().cc;
if (cc == .C) {
const ty = self.air.typeOfIndex(inst);
const arg_classes = abi.classifyType(ty, self.target);
for (arg_classes) |class| {
if (class != .none) {
self.arg_index += 1;
}
}
} else {
self.arg_index += 1;
}
return arg;
}
fn airBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {

View File

@ -52,6 +52,7 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
return memory;
},
.Bool => return direct,
.Array => return memory,
.ErrorUnion => {
const has_tag = ty.errorUnionSet().hasRuntimeBitsIgnoreComptime();
const has_pl = ty.errorUnionPayload().hasRuntimeBitsIgnoreComptime();
@ -73,16 +74,13 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
if (ty.isSlice()) return memory;
return direct;
},
.Array => {
if (ty.arrayLen() == 1) return direct;
return memory;
},
.Union => {
const layout = ty.unionGetLayout(target);
if (layout.payload_size == 0 and layout.tag_size != 0) {
return classifyType(ty.unionTagType().?, target);
}
return classifyType(ty.errorUnionPayload(), target);
if (ty.unionFields().count() > 1) return memory;
return classifyType(ty.unionFields().values()[0].ty, target);
},
.AnyFrame, .Frame => return direct,
@ -100,3 +98,24 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
=> unreachable,
}
}
/// Resolves the scalar type that `ty` can be represented as.
/// Single-field structs and unions are unwrapped repeatedly until a type that
/// is neither a struct nor a union is reached; that type is returned.
/// A union whose layout has no payload but does have a tag resolves through
/// its tag type instead.
/// Asserts that every struct, and every union resolved through its fields,
/// has exactly one field.
pub fn scalarType(ty: Type, target: std.Target) Type {
    var current = ty;
    while (true) {
        switch (current.zigTypeTag()) {
            .Struct => {
                std.debug.assert(current.structFieldCount() == 1);
                current = current.structFieldType(0);
            },
            .Union => {
                const layout = current.unionGetLayout(target);
                const tag_only = layout.payload_size == 0 and layout.tag_size != 0;
                if (tag_only) {
                    current = current.unionTagType().?;
                } else {
                    std.debug.assert(current.unionFields().count() == 1);
                    current = current.unionFields().values()[0].ty;
                }
            },
            // Anything else is already a scalar as far as this function is concerned.
            else => return current,
        }
    }
}