From 5f2d0d414dc44af7bda0e8d805d038e8f1a6f9d3 Mon Sep 17 00:00:00 2001
From: Luuk de Gram
Date: Sun, 24 Apr 2022 21:49:12 +0200
Subject: [PATCH] wasm: Implement codegen for C-ABI

This implements passing arguments and storing return values correctly
for the C ABI as specified by the tool-conventions document:
https://github.com/WebAssembly/tool-conventions/blob/main/BasicCABI.md

There's definitely room for better codegen in follow-up commits.
---
 src/arch/wasm/CodeGen.zig | 228 +++++++++++++++++++++++++++++++-------
 src/arch/wasm/abi.zig     |  29 ++++-
 2 files changed, 209 insertions(+), 48 deletions(-)

diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
index 4586f5624b..8eadfe6cd8 100644
--- a/src/arch/wasm/CodeGen.zig
+++ b/src/arch/wasm/CodeGen.zig
@@ -21,6 +21,7 @@ const Air = @import("../../Air.zig");
 const Liveness = @import("../../Liveness.zig");
 const Mir = @import("Mir.zig");
 const Emit = @import("Emit.zig");
+const abi = @import("abi.zig");
 
 /// Wasm Value, created when generating an instruction
 const WValue = union(enum) {
@@ -722,18 +723,15 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
             const bits = ty.floatBits(target);
             if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
             if (bits == 64) break :blk wasm.Valtype.f64;
+            if (bits == 128) break :blk wasm.Valtype.i64;
             return wasm.Valtype.i32; // represented as pointer to stack
         },
-        .Int => blk: {
+        .Int, .Enum => blk: {
             const info = ty.intInfo(target);
             if (info.bits <= 32) break :blk wasm.Valtype.i32;
-            if (info.bits > 32 and info.bits <= 64) break :blk wasm.Valtype.i64;
+            if (info.bits > 32 and info.bits <= 128) break :blk wasm.Valtype.i64;
             break :blk wasm.Valtype.i32; // represented as pointer to stack
         },
-        .Enum => {
-            var buf: Type.Payload.Bits = undefined;
-            return typeToValtype(ty.intTagType(&buf), target);
-        },
         else => wasm.Valtype.i32, // all represented as reference/immediate
     };
 }
@@ -787,33 +785,46 @@ fn allocLocal(self: *Self, ty: Type) InnerError!WValue {
 /// Generates a `wasm.Type` from a given function type.
 /// Memory is owned by the caller.
-fn genFunctype(gpa: Allocator, fn_ty: Type, target: std.Target) !wasm.Type { +fn genFunctype(gpa: Allocator, fn_info: Type.Payload.Function.Data, target: std.Target) !wasm.Type { var params = std.ArrayList(wasm.Valtype).init(gpa); defer params.deinit(); var returns = std.ArrayList(wasm.Valtype).init(gpa); defer returns.deinit(); - const return_type = fn_ty.fnReturnType(); - const want_sret = isByRef(return_type, target); - - if (want_sret) { - try params.append(typeToValtype(return_type, target)); - } - - // param types - if (fn_ty.fnParamLen() != 0) { - const fn_params = try gpa.alloc(Type, fn_ty.fnParamLen()); - defer gpa.free(fn_params); - fn_ty.fnParamTypes(fn_params); - for (fn_params) |param_type| { - if (!param_type.hasRuntimeBitsIgnoreComptime()) continue; - try params.append(typeToValtype(param_type, target)); + if (firstParamSRet(fn_info, target)) { + try params.append(typeToValtype(fn_info.return_type, target)); + } else if (fn_info.return_type.hasRuntimeBitsIgnoreComptime()) { + if (fn_info.cc == .C) { + const res_classes = abi.classifyType(fn_info.return_type, target); + assert(res_classes[0] == .direct and res_classes[1] == .none); + const scalar_type = abi.scalarType(fn_info.return_type, target); + try returns.append(typeToValtype(scalar_type, target)); + } else { + try returns.append(typeToValtype(fn_info.return_type, target)); } } - // return type - if (!want_sret and return_type.hasRuntimeBitsIgnoreComptime()) { - try returns.append(typeToValtype(return_type, target)); + // param types + if (fn_info.param_types.len != 0) { + for (fn_info.param_types) |param_type| { + if (!param_type.hasRuntimeBitsIgnoreComptime()) continue; + + switch (fn_info.cc) { + .C => { + const param_classes = abi.classifyType(param_type, target); + for (param_classes) |class| { + if (class == .none) continue; + if (class == .direct) { + const scalar_type = abi.scalarType(param_type, target); + try params.append(typeToValtype(scalar_type, target)); + } else { + try params.append(typeToValtype(param_type, target)); + } + } + }, + else => try params.append(typeToValtype(param_type, target)), + } + } } return wasm.Type{ @@ -857,7 +868,7 @@ pub fn generate( } fn genFunc(self: *Self) InnerError!void { - var func_type = try genFunctype(self.gpa, self.decl.ty, self.target); + var func_type = try genFunctype(self.gpa, self.decl.ty.fnInfo(), self.target); defer func_type.deinit(self.gpa); self.decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type); @@ -957,21 +968,22 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu .args = &.{}, .return_value = .none, }; + if (cc == .Naked) return result; + var args = std.ArrayList(WValue).init(self.gpa); defer args.deinit(); - const ret_ty = fn_ty.fnReturnType(); // Check if we store the result as a pointer to the stack rather than // by value - if (isByRef(ret_ty, self.target)) { + if (firstParamSRet(fn_ty.fnInfo(), self.target)) { // the sret arg will be passed as first argument, therefore we // set the `return_value` before allocating locals for regular args. 
         result.return_value = .{ .local = self.local_index };
         self.local_index += 1;
     }
+
     switch (cc) {
-        .Naked => return result,
-        .Unspecified, .C => {
+        .Unspecified => {
             for (param_types) |ty| {
                 if (!ty.hasRuntimeBitsIgnoreComptime()) {
                     continue;
                 }
@@ -981,12 +993,105 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
                 self.local_index += 1;
             }
         },
-        else => return self.fail("TODO implement function parameters for cc '{}' on wasm", .{cc}),
+        .C => {
+            for (param_types) |ty| {
+                const ty_classes = abi.classifyType(ty, self.target);
+                for (ty_classes) |class| {
+                    if (class == .none) continue;
+                    try args.append(.{ .local = self.local_index });
+                    self.local_index += 1;
+                }
+            }
+        },
+        else => return self.fail("calling convention '{s}' not supported for Wasm", .{@tagName(cc)}),
     }
     result.args = args.toOwnedSlice();
     return result;
 }
 
+fn firstParamSRet(fn_info: Type.Payload.Function.Data, target: std.Target) bool {
+    switch (fn_info.cc) {
+        .Unspecified, .Inline => return isByRef(fn_info.return_type, target),
+        .C => {
+            const ty_classes = abi.classifyType(fn_info.return_type, target);
+            if (ty_classes[0] == .indirect) return true;
+            if (ty_classes[0] == .direct and ty_classes[1] == .direct) return true;
+            return false;
+        },
+        else => return false,
+    }
+}
+
+/// Lowers a Zig type and its value according to the given calling convention,
+/// ensuring it matches the ABI.
+fn lowerArg(self: *Self, cc: std.builtin.CallingConvention, ty: Type, value: WValue) !void {
+    if (cc != .C) {
+        return self.lowerToStack(value);
+    }
+
+    const ty_classes = abi.classifyType(ty, self.target);
+    assert(ty_classes[0] != .none);
+    switch (ty.zigTypeTag()) {
+        .Struct, .Union => {
+            if (ty_classes[0] == .indirect) {
+                return self.lowerToStack(value);
+            }
+            assert(ty_classes[0] == .direct);
+            const scalar_type = abi.scalarType(ty, self.target);
+            const abi_size = scalar_type.abiSize(self.target);
+            const opcode = buildOpcode(.{
+                .op = .load,
+                .width = @intCast(u8, abi_size * 8),
+                .signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
+                .valtype1 = typeToValtype(scalar_type, self.target),
+            });
+            try self.emitWValue(value);
+            try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
+                .offset = value.offset(),
+                .alignment = scalar_type.abiAlignment(self.target),
+            });
+        },
+        .Int, .Float => {
+            if (ty_classes[1] == .none) {
+                return self.lowerToStack(value);
+            }
+            assert(ty_classes[0] == .direct and ty_classes[1] == .direct);
+            assert(ty.abiSize(self.target) == 16);
+            // This is an integer or float that must be lowered as two i64 values.
+            try self.emitWValue(value);
+            try self.addMemArg(.i64_load, .{ .offset = value.offset(), .alignment = 8 });
+            try self.emitWValue(value);
+            try self.addMemArg(.i64_load, .{ .offset = value.offset() + 8, .alignment = 8 });
+        },
+        else => return self.lowerToStack(value),
+    }
+}
+
+/// Lowers a `WValue` to the stack. When `value` is a `.stack_offset`, the
+/// pointer to that offset is computed and used instead.
+/// The value is left on the stack and is not stored in any temporary.
+fn lowerToStack(self: *Self, value: WValue) !void {
+    switch (value) {
+        .stack_offset => |offset| {
+            try self.emitWValue(value);
+            if (offset > 0) {
+                switch (self.arch()) {
+                    .wasm32 => {
+                        try self.addImm32(@bitCast(i32, offset));
+                        try self.addTag(.i32_add);
+                    },
+                    .wasm64 => {
+                        try self.addImm64(offset);
+                        try self.addTag(.i64_add);
+                    },
+                    else => unreachable,
+                }
+            }
+        },
+        else => try self.emitWValue(value),
+    }
+}
+
 /// Creates a local for the initial stack value
 /// Asserts `initial_stack_value` is `.none`
 fn initializeStack(self: *Self) !void {
@@ -1489,11 +1594,31 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const operand = try self.resolveInst(un_op);
+    const ret_ty = self.decl.ty.fnReturnType();
 
     // result must be stored in the stack and we return a pointer
     // to the stack instead
     if (self.return_value != .none) {
         try self.store(self.return_value, operand, self.decl.ty.fnReturnType(), 0);
+    } else if (self.decl.ty.fnInfo().cc == .C and ret_ty.hasRuntimeBitsIgnoreComptime()) {
+        switch (ret_ty.zigTypeTag()) {
+            // Aggregate types can be lowered as a single scalar value
+            .Struct, .Union => {
+                const scalar_type = abi.scalarType(ret_ty, self.target);
+                try self.emitWValue(operand);
+                const opcode = buildOpcode(.{
+                    .op = .load,
+                    .width = @intCast(u8, scalar_type.abiSize(self.target) * 8),
+                    .signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
+                    .valtype1 = typeToValtype(scalar_type, self.target),
+                });
+                try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
+                    .offset = operand.offset(),
+                    .alignment = scalar_type.abiAlignment(self.target),
+                });
+            },
+            else => try self.emitWValue(operand),
+        }
     } else {
         try self.emitWValue(operand);
     }
@@ -1509,9 +1634,10 @@ fn airRetPtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         return self.allocStack(Type.usize); // create pointer to void
     }
 
-    if (isByRef(child_type, self.target)) {
+    if (firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
         return self.return_value;
     }
+
     return self.allocStackPtr(inst);
 }
 
@@ -1521,7 +1647,7 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ret_ty = self.air.typeOf(un_op).childType();
     if (!ret_ty.hasRuntimeBitsIgnoreComptime()) return WValue.none;
 
-    if (!isByRef(ret_ty, self.target)) {
+    if (!firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
         const result = try self.load(operand, ret_ty, 0);
         try self.emitWValue(result);
     }
@@ -1544,7 +1670,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         else => unreachable,
     };
     const ret_ty = fn_ty.fnReturnType();
-    const first_param_sret = isByRef(ret_ty, self.target);
+    const first_param_sret = firstParamSRet(fn_ty.fnInfo(), self.target);
 
     const callee: ?*Decl = blk: {
         const func_val = self.air.value(pl_op.operand) orelse break :blk null;
@@ -1554,7 +1680,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
             break :blk module.declPtr(func.data.owner_decl);
         } else if (func_val.castTag(.extern_fn)) |extern_fn| {
             const ext_decl = module.declPtr(extern_fn.data.owner_decl);
-            var func_type = try genFunctype(self.gpa, ext_decl.ty, self.target);
+            var func_type = try genFunctype(self.gpa, ext_decl.ty.fnInfo(), self.target);
             defer func_type.deinit(self.gpa);
             ext_decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
             try self.bin_file.addOrUpdateImport(ext_decl);
@@ -1579,10 +1705,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         const arg_ty = self.air.typeOf(arg_ref);
         if (!arg_ty.hasRuntimeBitsIgnoreComptime()) continue;
 
-        switch (arg_val) {
-            .stack_offset => try self.emitWValue(try self.buildPointerOffset(arg_val, 0, .new)),
-            else => try self.emitWValue(arg_val),
-        }
+        try self.lowerArg(fn_ty.fnInfo().cc, arg_ty, arg_val);
     }
 
     if (callee) |direct| {
@@ -1594,7 +1717,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         const operand = try self.resolveInst(pl_op.operand);
         try self.emitWValue(operand);
 
-        var fn_type = try genFunctype(self.gpa, fn_ty, self.target);
+        var fn_type = try genFunctype(self.gpa, fn_ty.fnInfo(), self.target);
         defer fn_type.deinit(self.gpa);
 
         const fn_type_index = try self.bin_file.putOrGetFuncType(fn_type);
@@ -1608,6 +1731,14 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         return WValue.none;
     } else if (first_param_sret) {
         return sret;
+        // TODO: Make this less fragile and optimize
+    } else if (fn_ty.fnInfo().cc == .C and (ret_ty.zigTypeTag() == .Struct or ret_ty.zigTypeTag() == .Union)) {
+        const result_local = try self.allocLocal(ret_ty);
+        try self.addLabel(.local_set, result_local.local);
+        const scalar_type = abi.scalarType(ret_ty, self.target);
+        const result = try self.allocStack(scalar_type);
+        try self.store(result, result_local, scalar_type, 0);
+        return result;
     } else {
         const result_local = try self.allocLocal(ret_ty);
         try self.addLabel(.local_set, result_local.local);
@@ -1749,9 +1880,20 @@ fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
 }
 
 fn airArg(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
-    _ = inst;
-    defer self.arg_index += 1;
-    return self.args[self.arg_index];
+    const arg = self.args[self.arg_index];
+    const cc = self.decl.ty.fnInfo().cc;
+    if (cc == .C) {
+        const ty = self.air.typeOfIndex(inst);
+        const arg_classes = abi.classifyType(ty, self.target);
+        for (arg_classes) |class| {
+            if (class != .none) {
+                self.arg_index += 1;
+            }
+        }
+    } else {
+        self.arg_index += 1;
+    }
+    return arg;
 }
 
 fn airBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
diff --git a/src/arch/wasm/abi.zig b/src/arch/wasm/abi.zig
index 99398fd00c..63b319613b 100644
--- a/src/arch/wasm/abi.zig
+++ b/src/arch/wasm/abi.zig
@@ -52,6 +52,7 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
             return memory;
         },
         .Bool => return direct,
+        .Array => return memory,
         .ErrorUnion => {
             const has_tag = ty.errorUnionSet().hasRuntimeBitsIgnoreComptime();
             const has_pl = ty.errorUnionPayload().hasRuntimeBitsIgnoreComptime();
@@ -73,16 +74,13 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
             if (ty.isSlice()) return memory;
             return direct;
         },
-        .Array => {
-            if (ty.arrayLen() == 1) return direct;
-            return memory;
-        },
         .Union => {
             const layout = ty.unionGetLayout(target);
             if (layout.payload_size == 0 and layout.tag_size != 0) {
                 return classifyType(ty.unionTagType().?, target);
            }
-            return classifyType(ty.errorUnionPayload(), target);
+            if (ty.unionFields().count() > 1) return memory;
+            return classifyType(ty.unionFields().values()[0].ty, target);
         },
 
         .AnyFrame, .Frame => return direct,
@@ -100,3 +98,24 @@
         => unreachable,
     }
 }
+
+/// Returns the scalar type used to represent a given type.
+/// Asserts the given type can be represented as a scalar, such as
+/// a struct with a single scalar field.
+pub fn scalarType(ty: Type, target: std.Target) Type {
+    switch (ty.zigTypeTag()) {
+        .Struct => {
+            std.debug.assert(ty.structFieldCount() == 1);
+            return scalarType(ty.structFieldType(0), target);
+        },
+        .Union => {
+            const layout = ty.unionGetLayout(target);
+            if (layout.payload_size == 0 and layout.tag_size != 0) {
+                return scalarType(ty.unionTagType().?, target);
+            }
+            std.debug.assert(ty.unionFields().count() == 1);
+            return scalarType(ty.unionFields().values()[0].ty, target);
+        },
+        else => return ty,
+    }
+}
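
For reviewers, a sketch of what this lowering means for exported functions.
This example is not part of the patch; the type and function names are
hypothetical, and the signatures in the comments assume a wasm32 target and
the classification rules in abi.zig:

    const S = extern struct { x: i32 };           // expected class: { direct, none }
    const Big = extern struct { a: i64, b: i64 }; // expected class: { indirect, none }

    // A single-scalar struct is returned directly as its scalar, so the
    // wasm signature should become `() -> (i32)`.
    export fn ret_s() S {
        return .{ .x = 42 };
    }

    // An aggregate larger than a scalar makes firstParamSRet() true: the
    // caller passes a pointer to result memory as a leading parameter,
    // giving the signature `(i32) -> ()` on wasm32.
    export fn ret_big() Big {
        return .{ .a = 1, .b = 2 };
    }

    // An indirect struct argument is passed by pointer as well: the callee
    // should see `(i32) -> ()`, with lowerArg() pushing the argument's
    // stack address rather than its contents.
    export fn take_big(v: Big) void {
        _ = v;
    }

    // A 128-bit integer classifies as { direct, direct }: lowerArg() splits
    // it into two i64 halves, giving the signature `(i64, i64) -> ()`.
    export fn take_u128(v: u128) void {
        _ = v;
    }

One way to sanity-check the lowering is to build a file like this with
`zig build-lib -target wasm32-freestanding -dynamic` and inspect the output
with an external disassembler such as wasm2wat.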