mirror of
https://github.com/ziglang/zig.git
synced 2026-01-21 06:45:24 +00:00
Implement @mulAdd for vectors
This commit is contained in:
parent
312536540b
commit
c8ed813097
124
src/Sema.zig
124
src/Sema.zig
@ -14499,19 +14499,24 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
|
||||
|
||||
const target = sema.mod.getTarget();
|
||||
|
||||
const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1);
|
||||
const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2);
|
||||
const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend);
|
||||
|
||||
switch (ty.zigTypeTag()) {
|
||||
.ComptimeFloat, .Float => {
|
||||
const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1);
|
||||
const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2);
|
||||
const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend);
|
||||
.ComptimeFloat, .Float, .Vector => {},
|
||||
else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}),
|
||||
}
|
||||
|
||||
const runtime_src = if (maybe_mulend1) |mulend1_val| rs: {
|
||||
if (maybe_mulend2) |mulend2_val| {
|
||||
if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
|
||||
const runtime_src = if (maybe_mulend1) |mulend1_val| rs: {
|
||||
if (maybe_mulend2) |mulend2_val| {
|
||||
if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
|
||||
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
|
||||
switch (ty.zigTypeTag()) {
|
||||
.ComptimeFloat, .Float => {
|
||||
const result_val = try Value.mulAdd(
|
||||
ty,
|
||||
mulend1_val,
|
||||
@ -14521,47 +14526,70 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
|
||||
target,
|
||||
);
|
||||
return sema.addConstant(ty, result_val);
|
||||
} else {
|
||||
break :rs addend_src;
|
||||
}
|
||||
} else {
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
break :rs mulend2_src;
|
||||
}
|
||||
} else rs: {
|
||||
if (maybe_mulend2) |mulend2_val| {
|
||||
if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
break :rs mulend1_src;
|
||||
};
|
||||
},
|
||||
.Vector => {
|
||||
const scalar_ty = ty.scalarType();
|
||||
switch (scalar_ty.zigTypeTag()) {
|
||||
.ComptimeFloat, .Float => {},
|
||||
else => return sema.fail(block, src, "expected vector of floats, found vector of '{}'", .{scalar_ty}),
|
||||
}
|
||||
|
||||
try sema.requireRuntimeBlock(block, runtime_src);
|
||||
return block.addInst(.{
|
||||
.tag = .mul_add,
|
||||
.data = .{ .pl_op = .{
|
||||
.operand = addend,
|
||||
.payload = try sema.addExtra(Air.Bin{
|
||||
.lhs = mulend1,
|
||||
.rhs = mulend2,
|
||||
}),
|
||||
} },
|
||||
});
|
||||
},
|
||||
.Vector => {
|
||||
const scalar_ty = ty.scalarType();
|
||||
switch (scalar_ty.zigTypeTag()) {
|
||||
.ComptimeFloat, .Float => {},
|
||||
else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{scalar_ty}),
|
||||
const vec_len = ty.vectorLen();
|
||||
const result_ty = try Type.vector(sema.arena, vec_len, scalar_ty);
|
||||
var mulend1_buf: Value.ElemValueBuffer = undefined;
|
||||
var mulend2_buf: Value.ElemValueBuffer = undefined;
|
||||
var addend_buf: Value.ElemValueBuffer = undefined;
|
||||
const elems = try sema.arena.alloc(Value, vec_len);
|
||||
for (elems) |*elem, i| {
|
||||
const mulend1_elem_val = mulend1_val.elemValueBuffer(i, &mulend1_buf);
|
||||
const mulend2_elem_val = mulend2_val.elemValueBuffer(i, &mulend2_buf);
|
||||
const addend_elem_val = addend_val.elemValueBuffer(i, &addend_buf);
|
||||
elem.* = try Value.mulAdd(
|
||||
scalar_ty,
|
||||
mulend1_elem_val,
|
||||
mulend2_elem_val,
|
||||
addend_elem_val,
|
||||
sema.arena,
|
||||
target,
|
||||
);
|
||||
}
|
||||
return sema.addConstant(
|
||||
result_ty,
|
||||
try Value.Tag.aggregate.create(sema.arena, elems),
|
||||
);
|
||||
},
|
||||
else => unreachable,
|
||||
}
|
||||
} else {
|
||||
break :rs addend_src;
|
||||
}
|
||||
return sema.fail(block, src, "TODO: implement @mulAdd for vectors", .{});
|
||||
},
|
||||
else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}),
|
||||
}
|
||||
} else {
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
break :rs mulend2_src;
|
||||
}
|
||||
} else rs: {
|
||||
if (maybe_mulend2) |mulend2_val| {
|
||||
if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
if (maybe_addend) |addend_val| {
|
||||
if (addend_val.isUndef()) return sema.addConstUndef(ty);
|
||||
}
|
||||
break :rs mulend1_src;
|
||||
};
|
||||
|
||||
try sema.requireRuntimeBlock(block, runtime_src);
|
||||
return block.addInst(.{
|
||||
.tag = .mul_add,
|
||||
.data = .{ .pl_op = .{
|
||||
.operand = addend,
|
||||
.payload = try sema.addExtra(Air.Bin{
|
||||
.lhs = mulend1,
|
||||
.rhs = mulend2,
|
||||
}),
|
||||
} },
|
||||
});
|
||||
}
|
||||
|
||||
fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
|
||||
|
||||
@ -5166,7 +5166,13 @@ pub const FuncGen = struct {
|
||||
intrinsic,
|
||||
libc: [*:0]const u8,
|
||||
};
|
||||
const strat: Strat = switch (ty.floatBits(target)) {
|
||||
|
||||
const scalar_ty = if (ty.zigTypeTag() == .Vector)
|
||||
ty.elemType()
|
||||
else
|
||||
ty;
|
||||
|
||||
const strat: Strat = switch (scalar_ty.floatBits(target)) {
|
||||
16, 32, 64 => Strat.intrinsic,
|
||||
80 => if (CType.longdouble.sizeInBits(target) == 80) Strat{ .intrinsic = {} } else Strat{ .libc = "__fmax" },
|
||||
// LLVM always lowers the fma builtin for f128 to fmal, which is for `long double`.
|
||||
@ -5175,17 +5181,46 @@ pub const FuncGen = struct {
|
||||
else => unreachable,
|
||||
};
|
||||
|
||||
const llvm_fn = switch (strat) {
|
||||
.intrinsic => self.getIntrinsic("llvm.fma", &.{llvm_ty}),
|
||||
.libc => |fn_name| self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: {
|
||||
const param_types = [_]*const llvm.Type{ llvm_ty, llvm_ty, llvm_ty };
|
||||
const fn_type = llvm.functionType(llvm_ty, &param_types, param_types.len, .False);
|
||||
break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type);
|
||||
switch (strat) {
|
||||
.intrinsic => {
|
||||
const llvm_fn = self.getIntrinsic("llvm.fma", &.{llvm_ty});
|
||||
const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
|
||||
return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
|
||||
},
|
||||
};
|
||||
.libc => |fn_name| {
|
||||
const scalar_llvm_ty = try self.dg.llvmType(scalar_ty);
|
||||
const llvm_fn = self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: {
|
||||
const param_types = [_]*const llvm.Type{ scalar_llvm_ty, scalar_llvm_ty, scalar_llvm_ty };
|
||||
const fn_type = llvm.functionType(scalar_llvm_ty, &param_types, param_types.len, .False);
|
||||
break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type);
|
||||
};
|
||||
|
||||
const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
|
||||
return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
|
||||
if (ty.zigTypeTag() == .Vector) {
|
||||
const llvm_i32 = self.context.intType(32);
|
||||
const vector_llvm_ty = try self.dg.llvmType(ty);
|
||||
|
||||
var i: usize = 0;
|
||||
var vector = vector_llvm_ty.getUndef();
|
||||
while (i < ty.vectorLen()) : (i += 1) {
|
||||
const index_i32 = llvm_i32.constInt(i, .False);
|
||||
|
||||
const mulend1_elem = self.builder.buildExtractElement(mulend1, index_i32, "");
|
||||
const mulend2_elem = self.builder.buildExtractElement(mulend2, index_i32, "");
|
||||
const addend_elem = self.builder.buildExtractElement(addend, index_i32, "");
|
||||
|
||||
const params = [_]*const llvm.Value{ mulend1_elem, mulend2_elem, addend_elem };
|
||||
const mul_add = self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
|
||||
|
||||
vector = self.builder.buildInsertElement(vector, mul_add, index_i32, "");
|
||||
}
|
||||
|
||||
return vector;
|
||||
} else {
|
||||
const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
|
||||
return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn airShlWithOverflow(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
|
||||
|
||||
@ -78,3 +78,136 @@ fn testMulAdd128() !void {
|
||||
var c: f128 = 6.25;
|
||||
try expect(@mulAdd(f128, a, b, c) == 20);
|
||||
}
|
||||
|
||||
/// Checks `@mulAdd` on 4-lane `f16` vectors.
///
/// Every lane computes 5.5 * 2.5 + 6.25, so each lane of the result is
/// expected to compare equal to 20. Any failed lane check is returned as an
/// error from `expect`.
fn vector16() !void {
    const V = @Vector(4, f16);

    // Declared with `var`, exactly as the operands were originally written.
    var mulend1 = V{ 5.5, 5.5, 5.5, 5.5 };
    var mulend2 = V{ 2.5, 2.5, 2.5, 2.5 };
    var addend = V{ 6.25, 6.25, 6.25, 6.25 };
    var fused = @mulAdd(V, mulend1, mulend2, addend);

    // TODO use `expectEqual` instead once stage2 supports it:
    // var expected = V{ 20, 20, 20, 20 };
    // try expectEqual(expected, fused);

    try expect(fused[0] == 20);
    try expect(fused[1] == 20);
    try expect(fused[2] == 20);
    try expect(fused[3] == 20);
}
|
||||
|
||||
// Runs the `f16` vector `@mulAdd` checks once at compile time and once at
// runtime. The test is skipped on stage1 and on the backends listed below.
test "vector f16" {
    switch (builtin.zig_backend) {
        .stage1 => return error.SkipZigTest,
        // TODO: these backends are skipped for now.
        .stage2_c,
        .stage2_wasm,
        .stage2_x86_64,
        .stage2_arm,
        .stage2_aarch64,
        => return error.SkipZigTest,
        else => {},
    }

    comptime try vector16();
    try vector16();
}
|
||||
|
||||
/// Checks `@mulAdd` on 4-lane `f32` vectors.
///
/// Every lane computes 5.5 * 2.5 + 6.25, so each lane of the result is
/// expected to compare equal to 20. Any failed lane check is returned as an
/// error from `expect`.
fn vector32() !void {
    const V = @Vector(4, f32);

    // Declared with `var`, exactly as the operands were originally written.
    var mulend1 = V{ 5.5, 5.5, 5.5, 5.5 };
    var mulend2 = V{ 2.5, 2.5, 2.5, 2.5 };
    var addend = V{ 6.25, 6.25, 6.25, 6.25 };
    var fused = @mulAdd(V, mulend1, mulend2, addend);

    // TODO use `expectEqual` instead once stage2 supports it:
    // var expected = V{ 20, 20, 20, 20 };
    // try expectEqual(expected, fused);

    try expect(fused[0] == 20);
    try expect(fused[1] == 20);
    try expect(fused[2] == 20);
    try expect(fused[3] == 20);
}
|
||||
|
||||
// Runs the `f32` vector `@mulAdd` checks once at compile time and once at
// runtime. The test is skipped on stage1 and on the backends listed below.
test "vector f32" {
    switch (builtin.zig_backend) {
        .stage1 => return error.SkipZigTest,
        // TODO: these backends are skipped for now.
        .stage2_c,
        .stage2_wasm,
        .stage2_x86_64,
        .stage2_arm,
        .stage2_aarch64,
        => return error.SkipZigTest,
        else => {},
    }

    comptime try vector32();
    try vector32();
}
|
||||
|
||||
/// Checks `@mulAdd` on 4-lane `f64` vectors.
///
/// Every lane computes 5.5 * 2.5 + 6.25, so each lane of the result is
/// expected to compare equal to 20. Any failed lane check is returned as an
/// error from `expect`.
fn vector64() !void {
    const V = @Vector(4, f64);

    // Declared with `var`, exactly as the operands were originally written.
    var mulend1 = V{ 5.5, 5.5, 5.5, 5.5 };
    var mulend2 = V{ 2.5, 2.5, 2.5, 2.5 };
    var addend = V{ 6.25, 6.25, 6.25, 6.25 };
    var fused = @mulAdd(V, mulend1, mulend2, addend);

    // TODO use `expectEqual` instead once stage2 supports it:
    // var expected = V{ 20, 20, 20, 20 };
    // try expectEqual(expected, fused);

    try expect(fused[0] == 20);
    try expect(fused[1] == 20);
    try expect(fused[2] == 20);
    try expect(fused[3] == 20);
}
|
||||
|
||||
// Runs the `f64` vector `@mulAdd` checks once at compile time and once at
// runtime. The test is skipped on stage1 and on the backends listed below.
test "vector f64" {
    switch (builtin.zig_backend) {
        .stage1 => return error.SkipZigTest,
        // TODO: these backends are skipped for now.
        .stage2_c,
        .stage2_wasm,
        .stage2_x86_64,
        .stage2_arm,
        .stage2_aarch64,
        => return error.SkipZigTest,
        else => {},
    }

    comptime try vector64();
    try vector64();
}
|
||||
|
||||
/// Checks `@mulAdd` on 4-lane `f80` vectors.
///
/// Every lane computes 5.5 * 2.5 + 6.25, so each lane of the result is
/// expected to compare equal to 20. Any failed lane check is returned as an
/// error from `expect`.
fn vector80() !void {
    const V = @Vector(4, f80);

    // Declared with `var`, exactly as the operands were originally written.
    var mulend1 = V{ 5.5, 5.5, 5.5, 5.5 };
    var mulend2 = V{ 2.5, 2.5, 2.5, 2.5 };
    var addend = V{ 6.25, 6.25, 6.25, 6.25 };
    var fused = @mulAdd(V, mulend1, mulend2, addend);

    try expect(fused[0] == 20);
    try expect(fused[1] == 20);
    try expect(fused[2] == 20);
    try expect(fused[3] == 20);
}
|
||||
|
||||
// Would run the `f80` vector `@mulAdd` checks at compile time and at runtime,
// but is currently always skipped:
// https://github.com/ziglang/zig/issues/11030
test "vector f80" {
    // NOTE(review): the constant-true condition presumably exists so that the
    // calls below are not reported as unreachable code while the test is
    // disabled; confirm before simplifying.
    if (true) return error.SkipZigTest;

    comptime try vector80();
    try vector80();
}
|
||||
|
||||
/// Checks `@mulAdd` on 4-lane `f128` vectors.
///
/// Every lane computes 5.5 * 2.5 + 6.25, so each lane of the result is
/// expected to compare equal to 20. Any failed lane check is returned as an
/// error from `expect`.
fn vector128() !void {
    const V = @Vector(4, f128);

    // Declared with `var`, exactly as the operands were originally written.
    var mulend1 = V{ 5.5, 5.5, 5.5, 5.5 };
    var mulend2 = V{ 2.5, 2.5, 2.5, 2.5 };
    var addend = V{ 6.25, 6.25, 6.25, 6.25 };
    var fused = @mulAdd(V, mulend1, mulend2, addend);

    // TODO use `expectEqual` instead once stage2 supports it:
    // var expected = V{ 20, 20, 20, 20 };
    // try expectEqual(expected, fused);

    try expect(fused[0] == 20);
    try expect(fused[1] == 20);
    try expect(fused[2] == 20);
    try expect(fused[3] == 20);
}
|
||||
|
||||
// Runs the `f128` vector `@mulAdd` checks once at compile time and once at
// runtime. The test is skipped on stage1 and on the backends listed below.
test "vector f128" {
    switch (builtin.zig_backend) {
        .stage1 => return error.SkipZigTest,
        // TODO: these backends are skipped for now.
        .stage2_c,
        .stage2_wasm,
        .stage2_x86_64,
        .stage2_arm,
        .stage2_aarch64,
        => return error.SkipZigTest,
        else => {},
    }

    comptime try vector128();
    try vector128();
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user