spirv: use new vector stuff for arithOp and shift

This commit is contained in:
Robin Voetter 2024-01-15 23:38:43 +01:00
parent cb9e20da00
commit 403c6262bb
No known key found for this signature in database

View File

@ -1782,6 +1782,19 @@ const DeclGen = struct {
wip.dg.gpa.free(wip.results);
}
/// Return the scalar type of an input vector. This type is expected to be a vector
/// if `wip.is_vector`, and a scalar otherwise.
fn scalarType(wip: WipElementWise, ty: Type) Type {
const mod = wip.dg.module;
if (wip.is_vector) {
assert(ty.isVector(mod));
return ty.childType(mod);
} else {
assert(!ty.isVector(mod));
return ty;
}
}
/// Utility function to extract the element at a particular index in an
/// input vector. This type is expected to be a vector if `wip.is_vector`, and
/// a scalar otherwise.
@ -1789,7 +1802,7 @@ const DeclGen = struct {
const mod = wip.dg.module;
if (wip.is_vector) {
assert(ty.isVector(mod));
return try wip.dg.extractField(ty, value, @intCast(index));
return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
} else {
assert(!ty.isVector(mod));
assert(index == 0);
@ -2331,36 +2344,45 @@ const DeclGen = struct {
const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs);
const result_ty = self.typeOfIndex(inst);
const result_ty_ref = try self.resolveType(result_ty, .direct);
const result_id = self.spv.allocId();
// Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
// so just manually upcast it if required.
const shift_ty_ref = try self.resolveType(self.typeOf(bin_op.rhs), .direct);
const shift_id = if (shift_ty_ref != result_ty_ref) blk: {
const shift_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
.id_result_type = self.typeId(result_ty_ref),
.id_result = shift_id,
.unsigned_value = rhs_id,
});
break :blk shift_id;
} else rhs_id;
// TODO(robin)
const args = .{
.id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.base = lhs_id,
.shift = shift_id,
};
var wip = try self.elementWise(result_ty);
defer wip.deinit();
if (result_ty.isSignedInt(mod)) {
try self.func.body.emit(self.spv.gpa, signed, args);
} else {
try self.func.body.emit(self.spv.gpa, unsigned, args);
const shift_ty = wip.scalarType(self.typeOf(bin_op.rhs));
const shift_ty_ref = try self.resolveType(shift_ty, .direct);
for (0..wip.results.len) |i| {
const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
const shift_id = if (shift_ty_ref != wip.result_ty_ref) blk: {
const shift_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
.id_result_type = wip.scalar_ty_id,
.id_result = shift_id,
.unsigned_value = rhs_elem_id,
});
break :blk shift_id;
} else rhs_elem_id;
const args = .{
.id_result_type = wip.scalar_ty_id,
.id_result = wip.allocId(i),
.base = lhs_elem_id,
.shift = shift_id,
};
if (result_ty.isSignedInt(mod)) {
try self.func.body.emit(self.spv.gpa, signed, args);
} else {
try self.func.body.emit(self.spv.gpa, unsigned, args);
}
}
return result_id;
return try wip.finalize();
}
fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef {
@ -2483,35 +2505,14 @@ const DeclGen = struct {
fn arithOp(
self: *DeclGen,
ty: Type,
lhs_id_: IdRef,
rhs_id_: IdRef,
lhs_id: IdRef,
rhs_id: IdRef,
comptime fop: Opcode,
comptime sop: Opcode,
comptime uop: Opcode,
/// true if this operation holds under modular arithmetic.
comptime modular: bool,
) !IdRef {
var rhs_id = rhs_id_;
var lhs_id = lhs_id_;
const mod = self.module;
const result_ty_ref = try self.resolveType(ty, .direct);
if (ty.isVector(mod)) {
const child_ty = ty.childType(mod);
const vector_len = ty.vectorLen(mod);
const constituents = try self.gpa.alloc(IdRef, vector_len);
defer self.gpa.free(constituents);
for (constituents, 0..) |*constituent, i| {
const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
}
return self.constructArray(ty, constituents);
}
// Binary operations are generally applicable to both scalar and vector operations
// in SPIR-V, but int and float versions of operations require different opcodes.
const info = try self.arithmeticTypeInfo(ty);
@ -2520,17 +2521,7 @@ const DeclGen = struct {
.composite_integer => {
return self.todo("binary operations for composite integers", .{});
},
.strange_integer => blk: {
if (!modular) {
lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info);
rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info);
}
break :blk switch (info.signedness) {
.signed => @as(usize, 1),
.unsigned => @as(usize, 2),
};
},
.integer => switch (info.signedness) {
.integer, .strange_integer => switch (info.signedness) {
.signed => @as(usize, 1),
.unsigned => @as(usize, 2),
},
@ -2538,24 +2529,41 @@ const DeclGen = struct {
.bool => unreachable,
};
const result_id = self.spv.allocId();
const operands = .{
.id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.operand_1 = lhs_id,
.operand_2 = rhs_id,
};
var wip = try self.elementWise(ty);
defer wip.deinit();
for (0..wip.results.len) |i| {
const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
switch (opcode_index) {
0 => try self.func.body.emit(self.spv.gpa, fop, operands),
1 => try self.func.body.emit(self.spv.gpa, sop, operands),
2 => try self.func.body.emit(self.spv.gpa, uop, operands),
else => unreachable,
const lhs_norm_id = if (modular and info.class == .strange_integer)
try self.normalizeInt(wip.scalar_ty_ref, lhs_elem_id, info)
else
lhs_elem_id;
const rhs_norm_id = if (modular and info.class == .strange_integer)
try self.normalizeInt(wip.scalar_ty_ref, rhs_elem_id, info)
else
rhs_elem_id;
const operands = .{
.id_result_type = wip.scalar_ty_id,
.id_result = wip.allocId(i),
.operand_1 = lhs_norm_id,
.operand_2 = rhs_norm_id,
};
switch (opcode_index) {
0 => try self.func.body.emit(self.spv.gpa, fop, operands),
1 => try self.func.body.emit(self.spv.gpa, sop, operands),
2 => try self.func.body.emit(self.spv.gpa, uop, operands),
else => unreachable,
}
// TODO: Trap on overflow? Probably going to be annoying.
// TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
}
// TODO: Trap on overflow? Probably going to be annoying.
// TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
return result_id;
return try wip.finalize();
}
fn airAddSubOverflow(