spirv: element-wise operation helper

This commit is contained in:
Robin Voetter 2024-01-15 23:06:54 +01:00
parent 747f4ae3f5
commit cb9e20da00
No known key found for this signature in database
2 changed files with 97 additions and 27 deletions

View File

@ -1760,6 +1760,92 @@ const DeclGen = struct {
return union_layout;
}
/// This structure is used as helper for element-wise operations. It is intended
/// to be used with both vectors and single elements.
const WipElementWise = struct {
dg: *DeclGen,
result_ty: Type,
/// Always in direct representation.
result_ty_ref: CacheRef,
scalar_ty: Type,
/// Always in direct representation.
scalar_ty_ref: CacheRef,
scalar_ty_id: IdRef,
/// True if the input is actually a vector type.
is_vector: bool,
/// The element-wise operation should fill these results before calling finalize().
/// These should all be in **direct** representation! `finalize()` will convert
/// them to indirect if required.
results: []IdRef,
fn deinit(wip: *WipElementWise) void {
wip.dg.gpa.free(wip.results);
}
/// Utility function to extract the element at a particular index in an
/// input vector. This type is expected to be a vector if `wip.is_vector`, and
/// a scalar otherwise.
fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef {
const mod = wip.dg.module;
if (wip.is_vector) {
assert(ty.isVector(mod));
return try wip.dg.extractField(ty, value, @intCast(index));
} else {
assert(!ty.isVector(mod));
assert(index == 0);
return value;
}
}
/// Turns the results of this WipElementWise into a result. This can either
/// be a vector or single element, depending on `result_ty`.
/// After calling this function, this WIP is no longer usable.
/// Results is in `direct` representation.
fn finalize(wip: *WipElementWise) !IdRef {
if (wip.is_vector) {
// Convert all the constituents to indirect, as required for the array.
for (wip.results) |*result| {
result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*);
}
return try wip.dg.constructArray(wip.result_ty, wip.results);
} else {
return wip.results[0];
}
}
/// Allocate a result id at a particular index, and return it.
fn allocId(wip: *WipElementWise, index: usize) IdRef {
assert(wip.is_vector or index == 0);
wip.results[index] = wip.dg.spv.allocId();
return wip.results[index];
}
};
/// Create a new element-wise operation.
fn elementWise(self: *DeclGen, result_ty: Type) !WipElementWise {
const mod = self.module;
// For now, this operation also reasons in terms of `.direct` representation.
const result_ty_ref = try self.resolveType(result_ty, .direct);
const is_vector = result_ty.isVector(mod);
const num_results = if (is_vector) result_ty.vectorLen(mod) else 1;
const results = try self.gpa.alloc(IdRef, num_results);
for (results) |*result| result.* = undefined;
const scalar_ty = if (is_vector) result_ty.childType(mod) else result_ty;
const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
return .{
.dg = self,
.result_ty = result_ty,
.result_ty_ref = result_ty_ref,
.scalar_ty = scalar_ty,
.scalar_ty_ref = scalar_ty_ref,
.scalar_ty_id = self.typeId(scalar_ty_ref),
.is_vector = is_vector,
.results = results,
};
}
/// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
/// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
/// points. The test executor will then be able to invoke these to run the tests.
@ -2214,34 +2300,17 @@ const DeclGen = struct {
}
fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef {
const mod = self.module;
if (ty.isVector(mod)) {
const child_ty = ty.childType(mod);
const vector_len = ty.vectorLen(mod);
const constituents = try self.gpa.alloc(IdRef, vector_len);
defer self.gpa.free(constituents);
for (constituents, 0..) |*constituent, i| {
const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
const result_id = try self.binOpSimple(child_ty, lhs_index_id, rhs_index_id, opcode);
constituent.* = try self.convertToIndirect(child_ty, result_id);
}
return try self.constructArray(ty, constituents);
var wip = try self.elementWise(ty);
defer wip.deinit();
for (0..wip.results.len) |i| {
try self.func.body.emit(self.spv.gpa, opcode, .{
.id_result_type = wip.scalar_ty_id,
.id_result = wip.allocId(i),
.operand_1 = try wip.elementAt(ty, lhs_id, i),
.operand_2 = try wip.elementAt(ty, rhs_id, i),
});
}
const result_id = self.spv.allocId();
const result_type_id = try self.resolveTypeId(ty);
try self.func.body.emit(self.spv.gpa, opcode, .{
.id_result_type = result_type_id,
.id_result = result_id,
.operand_1 = lhs_id,
.operand_2 = rhs_id,
});
return result_id;
return try wip.finalize();
}
fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {

View File

@ -12,6 +12,7 @@ const math = std.math;
test "assignment operators" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var i: u32 = 0;
i += 5;