spirv: air mul_add

This commit is contained in:
Robin Voetter 2024-01-21 12:17:19 +01:00
parent 345d6e280d
commit 7dfd403da1
No known key found for this signature in database
2 changed files with 73 additions and 42 deletions

View File

@ -2160,9 +2160,9 @@ const DeclGen = struct {
const air_tags = self.air.instructions.items(.tag);
const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) {
// zig fmt: off
.add, .add_wrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
.sub, .sub_wrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
.mul, .mul_wrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
.add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
.sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
.mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
.div_float,
.div_float_optimized,
@ -2179,6 +2179,8 @@ const DeclGen = struct {
.sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
.shl_with_overflow => try self.airShlOverflow(inst),
.mul_add => try self.airMulAdd(inst),
.reduce, .reduce_optimized => try self.airReduce(inst),
.shuffle => try self.airShuffle(inst),
@ -2439,40 +2441,38 @@ const DeclGen = struct {
switch (info.class) {
.integer, .bool, .float => return value_id,
.composite_integer => unreachable, // TODO
.strange_integer => {
switch (info.signedness) {
.unsigned => {
const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
const result_id = self.spv.allocId();
const mask_id = try self.constInt(ty_ref, mask_value);
try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
.id_result_type = self.typeId(ty_ref),
.id_result = result_id,
.operand_1 = value_id,
.operand_2 = mask_id,
});
return result_id;
},
.signed => {
// Shift left and right so that we can copy the sign bit that way.
const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
const left_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
.id_result_type = self.typeId(ty_ref),
.id_result = left_id,
.base = value_id,
.shift = shift_amt_id,
});
const right_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
.id_result_type = self.typeId(ty_ref),
.id_result = right_id,
.base = left_id,
.shift = shift_amt_id,
});
return right_id;
},
}
.strange_integer => switch (info.signedness) {
.unsigned => {
const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
const result_id = self.spv.allocId();
const mask_id = try self.constInt(ty_ref, mask_value);
try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
.id_result_type = self.typeId(ty_ref),
.id_result = result_id,
.operand_1 = value_id,
.operand_2 = mask_id,
});
return result_id;
},
.signed => {
// Shift left and right so that we can copy the sign bit that way.
const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
const left_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
.id_result_type = self.typeId(ty_ref),
.id_result = left_id,
.base = value_id,
.shift = shift_amt_id,
});
const right_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
.id_result_type = self.typeId(ty_ref),
.id_result = right_id,
.base = left_id,
.shift = shift_amt_id,
});
return right_id;
},
},
}
}
@ -2761,6 +2761,42 @@ const DeclGen = struct {
);
}
/// Lowers the AIR `mul_add` instruction by computing `lhs * rhs + addend`
/// one element at a time.
///
/// Returns `null` when liveness reports the instruction's result as unused;
/// otherwise returns whatever `wip.finalize()` produces. Errors from operand
/// resolution, element extraction and instruction emission are propagated.
///
/// NOTE(review): each element is lowered to a separate OpFMul followed by an
/// OpFAdd rather than one fused operation -- confirm this matches the
/// rounding behavior expected from `@mulAdd`.
fn airMulAdd(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
    if (self.liveness.isUnused(inst)) return null;

    // The addend is stored in the `pl_op` operand slot, while the two
    // factors are read from the `Air.Bin` extra payload.
    const data = self.air.instructions.items(.data)[@intFromEnum(inst)];
    const factors = self.air.extraData(Air.Bin, data.pl_op.payload).data;

    const lhs_id = try self.resolve(factors.lhs);
    const rhs_id = try self.resolve(factors.rhs);
    const addend_id = try self.resolve(data.pl_op.operand);

    const result_ty = self.typeOfIndex(inst);
    // .mul_add is only emitted for floats
    assert(self.arithmeticTypeInfo(result_ty).class == .float);

    var wip = try self.elementWise(result_ty);
    defer wip.deinit();

    for (0..wip.results.len) |elem| {
        // Product of the two factors for this element.
        const product_id = self.spv.allocId();
        const lhs_elem = try wip.elementAt(result_ty, lhs_id, elem);
        const rhs_elem = try wip.elementAt(result_ty, rhs_id, elem);
        try self.func.body.emit(self.spv.gpa, .OpFMul, .{
            .id_result_type = wip.scalar_ty_id,
            .id_result = product_id,
            .operand_1 = lhs_elem,
            .operand_2 = rhs_elem,
        });

        // Add the addend; the sum is written to this element's result id.
        const sum_id = wip.allocId(elem);
        const addend_elem = try wip.elementAt(result_ty, addend_id, elem);
        try self.func.body.emit(self.spv.gpa, .OpFAdd, .{
            .id_result_type = wip.scalar_ty_id,
            .id_result = sum_id,
            .operand_1 = product_id,
            .operand_2 = addend_elem,
        });
    }

    return try wip.finalize();
}
fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
const mod = self.module;

View File

@ -10,7 +10,6 @@ test "@mulAdd" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try comptime testMulAdd();
try testMulAdd();
@ -37,7 +36,6 @@ test "@mulAdd f16" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
try comptime testMulAdd16();
@ -111,7 +109,6 @@ test "vector f16" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try comptime vector16();
try vector16();
@ -136,7 +133,6 @@ test "vector f32" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try comptime vector32();
try vector32();
@ -161,7 +157,6 @@ test "vector f64" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try comptime vector64();
try vector64();