diff --git a/src/Air.zig b/src/Air.zig index 268b6c8631..6a2013ba40 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -579,6 +579,10 @@ pub const Inst = struct { /// Uses the `prefetch` field. prefetch, + /// Computes `(a * b) + c`, but only rounds once. + /// Uses the `ty_pl` field. + mul_add, + /// Implements @fieldParentPtr builtin. /// Uses the `ty_pl` field. field_parent_ptr, @@ -724,6 +728,12 @@ pub const Bin = struct { rhs: Inst.Ref, }; +pub const MulAdd = struct { + mulend1: Inst.Ref, + mulend2: Inst.Ref, + addend: Inst.Ref, +}; + pub const FieldParentPtr = struct { field_ptr: Inst.Ref, field_index: u32, @@ -889,6 +899,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .aggregate_init, .union_init, .field_parent_ptr, + .mul_add, => return air.getRefType(datas[inst].ty_pl.ty), .not, diff --git a/src/Liveness.zig b/src/Liveness.zig index 7f007b5718..27a5fed335 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -464,6 +464,10 @@ fn analyzeInst( const extra = a.air.extraData(Air.Cmpxchg, inst_datas[inst].ty_pl.payload).data; return trackOperands(a, new_set, inst, main_tomb, .{ extra.ptr, extra.expected_value, extra.new_value }); }, + .mul_add => { + const extra = a.air.extraData(Air.MulAdd, inst_datas[inst].ty_pl.payload).data; + return trackOperands(a, new_set, inst, main_tomb, .{ extra.mulend1, extra.mulend2, extra.addend }); + }, .atomic_load => { const ptr = inst_datas[inst].atomic_load.ptr; return trackOperands(a, new_set, inst, main_tomb, .{ ptr, .none, .none }); diff --git a/src/Sema.zig b/src/Sema.zig index fb7209ae0e..4ef20a99ba 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -13518,8 +13518,84 @@ fn zirAtomicStore(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const extra = sema.code.extraData(Zir.Inst.MulAdd, inst_data.payload_index).data; const src = 
inst_data.src(); - return sema.fail(block, src, "TODO: Sema.zirMulAdd", .{}); + + const mulend1_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; + const mulend2_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node }; + const addend_src: LazySrcLoc = .{ .node_offset_builtin_call_arg3 = inst_data.src_node }; + + const mulend1 = sema.resolveInst(extra.mulend1); + const mulend2 = sema.resolveInst(extra.mulend2); + const addend = sema.resolveInst(extra.addend); + // All args have the same type + const ty = sema.typeOf(mulend1); + switch (ty.zigTypeTag()) { + .ComptimeFloat, .Float => {}, + .Vector => { + const scalar_ty = ty.scalarType(); + switch (scalar_ty.zigTypeTag()) { + .ComptimeFloat, .Float => {}, + else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{scalar_ty}), + } + }, + else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}), + } + + const target = sema.mod.getTarget(); + switch (ty.zigTypeTag()) { + .ComptimeFloat, .Float => { + const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1); + const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2); + const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend); + + if (maybe_mulend1) |mulend1_val| { + if (mulend1_val.isUndef()) + return sema.addConstUndef(ty); + } + + if (maybe_mulend2) |mulend2_val| { + if (mulend2_val.isUndef()) + return sema.addConstUndef(ty); + } + + if (maybe_addend) |addend_val| { + if (addend_val.isUndef()) + return sema.addConstUndef(ty); + } + + if (maybe_mulend1) |mulend1_val| { + if (maybe_mulend2) |mulend2_val| { + if (maybe_addend) |addend_val| { + const result_val = try Value.mulAdd( + ty, + mulend1_val, + mulend2_val, + addend_val, + sema.arena, + target, + ); + return sema.addConstant(ty, result_val); + } + } + } + + try sema.requireRuntimeBlock(block, src); + return block.addInst(.{ + 
.tag = .mul_add, + .data = .{ .ty_pl = .{ + .ty = try sema.addType(ty), + .payload = try sema.addExtra(Air.MulAdd{ + .mulend1 = mulend1, + .mulend2 = mulend2, + .addend = addend, + }), + } }, + }); + }, + .Vector => return sema.fail(block, src, "TODO: implement @mulAdd for vectors", .{}), + else => unreachable, + } } fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index b3447f43e7..6cfe29667d 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -632,6 +632,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .aggregate_init => try self.airAggregateInit(inst), .union_init => try self.airUnionInit(inst), .prefetch => try self.airPrefetch(inst), + .mul_add => try self.airMulAdd(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -3652,6 +3653,11 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, MCValue.dead, .{ prefetch.ptr, .none, .none }); } +fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void { + _ = inst; + return self.fail("TODO implement airMulAdd for aarch64", .{}); +} + fn resolveInst(self: *Self, inst: Air.Inst.Ref) InnerError!MCValue { // First section of indexes correspond to a set number of constant values. 
const ref_int = @enumToInt(inst); diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 80f0169ba5..80ab9806be 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -628,6 +628,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .aggregate_init => try self.airAggregateInit(inst), .union_init => try self.airUnionInit(inst), .prefetch => try self.airPrefetch(inst), + .mul_add => try self.airMulAdd(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -4086,6 +4087,11 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, MCValue.dead, .{ prefetch.ptr, .none, .none }); } +fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void { + _ = inst; + return self.fail("TODO implement airMulAdd for arm", .{}); +} + fn resolveInst(self: *Self, inst: Air.Inst.Ref) InnerError!MCValue { // First section of indexes correspond to a set number of constant values. 
const ref_int = @enumToInt(inst); diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 15600c09dd..d72ca22fb2 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -600,6 +600,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .aggregate_init => try self.airAggregateInit(inst), .union_init => try self.airUnionInit(inst), .prefetch => try self.airPrefetch(inst), + .mul_add => try self.airMulAdd(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -2203,6 +2204,11 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, MCValue.dead, .{ prefetch.ptr, .none, .none }); } +fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void { + _ = inst; + return self.fail("TODO implement airMulAdd for riscv64", .{}); +} + fn resolveInst(self: *Self, inst: Air.Inst.Ref) InnerError!MCValue { // First section of indexes correspond to a set number of constant values. const ref_int = @enumToInt(inst); diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index b293d20db9..c3e8bb7864 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1333,6 +1333,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue { .error_name, .errunion_payload_ptr_set, .field_parent_ptr, + .mul_add, // For these 4, probably best to wait until https://github.com/ziglang/zig/issues/10248 // is implemented in the frontend before implementing them here in the wasm backend. 
diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 62dc924124..d963d8f801 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -717,6 +717,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .aggregate_init => try self.airAggregateInit(inst), .union_init => try self.airUnionInit(inst), .prefetch => try self.airPrefetch(inst), + .mul_add => try self.airMulAdd(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -5559,6 +5560,11 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, MCValue.dead, .{ prefetch.ptr, .none, .none }); } +fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void { + _ = inst; + return self.fail("TODO implement airMulAdd for x86_64", .{}); +} + fn resolveInst(self: *Self, inst: Air.Inst.Ref) InnerError!MCValue { // First section of indexes correspond to a set number of constant values. 
const ref_int = @enumToInt(inst); diff --git a/src/codegen/c.zig b/src/codegen/c.zig index ba7bb6fa3a..e24ff0a6b0 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -1635,6 +1635,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO .trunc_float, => |tag| return f.fail("TODO: C backend: implement unary op for tag '{s}'", .{@tagName(tag)}), + .mul_add => return f.fail("TODO: C backend: implement @mulAdd", .{}), + .add_with_overflow => try airAddWithOverflow(f, inst), .sub_with_overflow => try airSubWithOverflow(f, inst), .mul_with_overflow => try airMulWithOverflow(f, inst), diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 446876dfe5..86573e182e 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -2194,6 +2194,7 @@ pub const FuncGen = struct { .sub_with_overflow => try self.airOverflow(inst, "llvm.ssub.with.overflow", "llvm.usub.with.overflow"), .mul_with_overflow => try self.airOverflow(inst, "llvm.smul.with.overflow", "llvm.umul.with.overflow"), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_add => try self.airMulAdd(inst), .bit_and, .bool_and => try self.airAnd(inst), .bit_or, .bool_or => try self.airOr(inst), @@ -3842,6 +3843,46 @@ pub const FuncGen = struct { return overflow_bit; } + fn airMulAdd(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { + if (self.liveness.isUnused(inst)) + return null; + + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.MulAdd, ty_pl.payload).data; + + const mulend1 = try self.resolveInst(extra.mulend1); + const mulend2 = try self.resolveInst(extra.mulend2); + const addend = try self.resolveInst(extra.addend); + + const ty = self.air.typeOfIndex(inst); + const llvm_ty = try self.dg.llvmType(ty); + const target = self.dg.module.getTarget(); + + const fn_val = switch (ty.floatBits(target)) { + 16, 32, 64 => blk: { + break :blk self.getIntrinsic("llvm.fma", &.{llvm_ty}); + }, + // TODO: using 
`llvm.fma` for f80 does not seem to work for all targets, needs further investigation. + 80 => return self.dg.todo("Implement mulAdd for f80", .{}), + 128 => blk: { + // LLVM incorrectly lowers the fma builtin for f128 to fmal, which is for + // `long double`. On some targets this will be correct; on others it will be incorrect. + if (target.longDoubleIsF128()) { + break :blk self.getIntrinsic("llvm.fma", &.{llvm_ty}); + } else { + break :blk self.dg.object.llvm_module.getNamedFunction("fmaq") orelse fn_blk: { + const param_types = [_]*const llvm.Type{ llvm_ty, llvm_ty, llvm_ty }; + const fn_type = llvm.functionType(llvm_ty, &param_types, param_types.len, .False); + break :fn_blk self.dg.object.llvm_module.addFunction("fmaq", fn_type); + }; + } + }, + else => unreachable, + }; + const params = [_]*const llvm.Value{ mulend1, mulend2, addend }; + return self.builder.buildCall(fn_val, &params, params.len, .C, .Auto, ""); + } + fn airShlWithOverflow(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { if (self.liveness.isUnused(inst)) return null; diff --git a/src/print_air.zig b/src/print_air.zig index 2149be764a..e1c495746f 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -252,6 +252,7 @@ const Writer = struct { .field_parent_ptr => try w.writeFieldParentPtr(s, inst), .wasm_memory_size => try w.writeWasmMemorySize(s, inst), .wasm_memory_grow => try w.writeWasmMemoryGrow(s, inst), + .mul_add => try w.writeMulAdd(s, inst), .add_with_overflow, .sub_with_overflow, @@ -358,6 +359,17 @@ const Writer = struct { }); } + fn writeMulAdd(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { + const ty_pl = w.air.instructions.items(.data)[inst].ty_pl; + const extra = w.air.extraData(Air.MulAdd, ty_pl.payload).data; + + try w.writeOperand(s, inst, 0, extra.mulend1); + try s.writeAll(", "); + try w.writeOperand(s, inst, 1, extra.mulend2); + try s.writeAll(", "); + try w.writeOperand(s, inst, 2, extra.addend); + } + fn writeFence(w: *Writer, s: anytype,
inst: Air.Inst.Index) @TypeOf(s).Error!void { const atomic_order = w.air.instructions.items(.data)[inst].fence; diff --git a/src/value.zig b/src/value.zig index e667c566b9..e8fd84c618 100644 --- a/src/value.zig +++ b/src/value.zig @@ -2931,7 +2931,7 @@ pub const Value = extern union { return fromBigInt(arena, result_bigint.toConst()); } - /// operands must be integers; handles undefined. + /// operands must be integers; handles undefined. pub fn bitwiseAnd(lhs: Value, rhs: Value, arena: Allocator) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); @@ -2951,7 +2951,7 @@ pub const Value = extern union { return fromBigInt(arena, result_bigint.toConst()); } - /// operands must be integers; handles undefined. + /// operands must be integers; handles undefined. pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, target: Target) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); @@ -2965,7 +2965,7 @@ pub const Value = extern union { return bitwiseXor(anded, all_ones, arena); } - /// operands must be integers; handles undefined. + /// operands must be integers; handles undefined. pub fn bitwiseOr(lhs: Value, rhs: Value, arena: Allocator) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); @@ -2984,7 +2984,7 @@ pub const Value = extern union { return fromBigInt(arena, result_bigint.toConst()); } - /// operands must be integers; handles undefined. + /// operands must be integers; handles undefined. 
pub fn bitwiseXor(lhs: Value, rhs: Value, arena: Allocator) !Value { if (lhs.isUndef() or rhs.isUndef()) return Value.initTag(.undef); @@ -4020,6 +4020,42 @@ pub const Value = extern union { } } + pub fn mulAdd(float_type: Type, mulend1: Value, mulend2: Value, addend: Value, arena: Allocator, target: Target) Allocator.Error!Value { + switch (float_type.floatBits(target)) { + 16 => { + if (true) { + // TODO: missing f16 implementation of FMA in `std.math.fma` or compiler-rt + @panic("TODO implement mulAdd for f16"); + } + }, + 32 => { + const m1 = mulend1.toFloat(f32); + const m2 = mulend2.toFloat(f32); + const a = addend.toFloat(f32); + return Value.Tag.float_32.create(arena, std.math.fma(f32, m1, m2, a)); + }, + 64 => { + const m1 = mulend1.toFloat(f64); + const m2 = mulend2.toFloat(f64); + const a = addend.toFloat(f64); + return Value.Tag.float_64.create(arena, std.math.fma(f64, m1, m2, a)); + }, + 80 => { + if (true) { + // TODO: missing f80 implementation of FMA in `std.math.fma` or compiler-rt + @panic("TODO implement mulAdd for f80"); + } + }, + 128 => { + const m1 = mulend1.toFloat(f128); + const m2 = mulend2.toFloat(f128); + const a = addend.toFloat(f128); + return Value.Tag.float_128.create(arena, std.math.fma(f128, m1, m2, a)); + }, + else => unreachable, + } + } + /// This type is not copyable since it may contain pointers to its inner data. 
pub const Payload = struct { tag: Tag, diff --git a/test/behavior/muladd.zig b/test/behavior/muladd.zig index bf50541b56..4ef0e44acb 100644 --- a/test/behavior/muladd.zig +++ b/test/behavior/muladd.zig @@ -2,29 +2,33 @@ const builtin = @import("builtin"); const expect = @import("std").testing.expect; test "@mulAdd" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO comptime try testMulAdd(); try testMulAdd(); } fn testMulAdd() !void { - { - var a: f16 = 5.5; - var b: f16 = 2.5; - var c: f16 = 6.25; + if (builtin.zig_backend == .stage1) { + const a: f16 = 5.5; + const b: f16 = 2.5; + const c: f16 = 6.25; try expect(@mulAdd(f16, a, b, c) == 20); } { - var a: f32 = 5.5; - var b: f32 = 2.5; - var c: f32 = 6.25; + const a: f32 = 5.5; + const b: f32 = 2.5; + const c: f32 = 6.25; try expect(@mulAdd(f32, a, b, c) == 20); } { - var a: f64 = 5.5; - var b: f64 = 2.5; - var c: f64 = 6.25; + const a: f64 = 5.5; + const b: f64 = 2.5; + const c: f64 = 6.25; try expect(@mulAdd(f64, a, b, c) == 20); } } @@ -35,7 +39,9 @@ test "@mulAdd f80" { return error.SkipZigTest; } - comptime try testMulAdd80(); + // TODO: missing f80 implementation of FMA in `std.math.fma` or compiler-rt + // comptime try testMulAdd80(); + try testMulAdd80(); } @@ -43,24 +49,27 @@ fn testMulAdd80() !void { var a: f16 = 5.5; var b: f80 = 2.5; var c: f80 = 6.25; - try expect(@mulAdd(f80, a, b, c) == 20); + try expect(@mulAdd(f80, a, b, c) == 20.0); } test "@mulAdd f128" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO - + if (builtin.zig_backend == .stage2_wasm) return 
error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.os.tag == .macos and builtin.cpu.arch == .aarch64) { // https://github.com/ziglang/zig/issues/9900 return error.SkipZigTest; } - comptime try testMullAdd128(); - try testMullAdd128(); + comptime try testMulAdd128(); + try testMulAdd128(); } -fn testMullAdd128() !void { - var a: f16 = 5.5; - var b: f128 = 2.5; - var c: f128 = 6.25; +fn testMulAdd128() !void { + const a: f16 = 5.5; + const b: f128 = 2.5; + const c: f128 = 6.25; try expect(@mulAdd(f128, a, b, c) == 20); }