spirv: use extended instructions whenever possible

This commit is contained in:
Ali Chraghi 2024-02-15 17:25:33 +03:30
parent 6fe90a913a
commit 44c31194e3
3 changed files with 139 additions and 81 deletions

View File

@ -632,11 +632,15 @@ const DeclGen = struct {
/// Checks whether the type can be directly translated to SPIR-V vectors.
///
/// A type qualifies when it is a Zig vector whose element type is numeric or
/// `bool` and whose length is 2, 3 or 4. When the target OS is OpenCL, lengths
/// 8 and 16 are accepted as well.
fn isVector(self: *DeclGen, ty: Type) bool {
    const mod = self.module;
    const target = self.getTarget();
    if (ty.zigTypeTag(mod) != .Vector) return false;
    const elem_ty = ty.childType(mod);

    const len = ty.vectorLen(mod);
    const is_scalar = elem_ty.isNumeric(mod) or elem_ty.toIntern() == .bool_type;
    // Lengths accepted regardless of target.
    const spirv_len = len > 1 and len <= 4;
    // Additional lengths accepted only when targeting OpenCL.
    const opencl_len = target.os.tag == .opencl and (len == 8 or len == 16);
    return is_scalar and (spirv_len or opencl_len);
}
fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
@ -1968,7 +1972,10 @@ const DeclGen = struct {
try self.func.prologue.emit(self.spv.gpa, .OpFunction, .{
.id_result_type = self.typeId(return_ty_ref),
.id_result = decl_id,
.function_control = .{}, // TODO: We can set inline here if the type requires it.
.function_control = switch (fn_info.cc) {
.Inline => .{ .Inline = true },
else => .{},
},
.function_type = prototype_id,
});
@ -2437,48 +2444,71 @@ const DeclGen = struct {
/// Emits the element-wise minimum (`op == .lt`) or maximum (`op == .gt`) of
/// `lhs_id` and `rhs_id`, both of type `result_ty`, and returns the result id.
///
/// Floats on OpenCL and every supported arithmetic class on Vulkan are lowered
/// to a single OpExtInst from the target's extended instruction set. Non-float
/// operands on OpenCL use a fallback: a comparison followed by OpSelect.
///
/// The extended-instruction path only handles `.lt` and `.gt`, and only the
/// `.opencl` and `.vulkan` OS tags; anything else hits `unreachable`.
fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
    const info = self.arithmeticTypeInfo(result_ty);
    const target = self.getTarget();

    // On OpenCL only fmin/fmax are used from the extended set, so every
    // non-float class takes the compare-and-select path instead.
    const use_backup_codegen = target.os.tag == .opencl and info.class != .float;
    // NOTE(review): the second argument presumably forces per-scalar processing
    // when the fallback is used -- confirm against `elementWise`.
    var wip = try self.elementWise(result_ty, use_backup_codegen);
    defer wip.deinit();

    for (wip.results, 0..) |*result_id, i| {
        const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
        const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);

        if (use_backup_codegen) {
            // Select lhs when `lhs op rhs` holds, rhs otherwise.
            const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id);
            result_id.* = self.spv.allocId();
            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
                .id_result_type = wip.ty_id,
                .id_result = result_id.*,
                .condition = cmp_id,
                .object_1 = lhs_elem_id,
                .object_2 = rhs_elem_id,
            });
        } else {
            // Opcode within the imported extended instruction set.
            const ext_inst: Word = switch (target.os.tag) {
                .opencl => switch (op) {
                    .lt => 28, // fmin
                    .gt => 27, // fmax
                    else => unreachable,
                },
                .vulkan => switch (info.class) {
                    .float => switch (op) {
                        .lt => 37, // FMin
                        .gt => 40, // FMax
                        else => unreachable,
                    },
                    .integer, .strange_integer => switch (info.signedness) {
                        .signed => switch (op) {
                            .lt => 39, // SMin
                            .gt => 42, // SMax
                            else => unreachable,
                        },
                        .unsigned => switch (op) {
                            .lt => 38, // UMin
                            .gt => 41, // UMax
                            else => unreachable,
                        },
                    },
                    .composite_integer => unreachable, // TODO
                    .bool => unreachable,
                },
                else => unreachable,
            };
            const set_id = switch (target.os.tag) {
                .opencl => try self.spv.importInstructionSet(.opencl),
                .vulkan => try self.spv.importInstructionSet(.glsl),
                else => unreachable,
            };

            result_id.* = self.spv.allocId();
            try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
                .id_result_type = wip.ty_id,
                .id_result = result_id.*,
                .set = set_id,
                .instruction = .{ .inst = ext_inst },
                .id_ref_4 = &.{ lhs_elem_id, rhs_elem_id },
            });
        }
    }
    return wip.finalize();
}
@ -2607,57 +2637,52 @@ const DeclGen = struct {
}
/// Lowers an AIR absolute-value instruction to an OpExtInst from the target's
/// extended instruction set (OpenCL.std on `.opencl`, GLSL.std.450 on
/// `.vulkan`) and returns the id of the result.
///
/// The opcode is chosen from the arithmetic class (and, on OpenCL, the
/// signedness) of the operand type. Other OS tags, composite integers and
/// bools hit `unreachable`.
fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
    const target = self.getTarget();
    const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
    const operand_id = try self.resolve(ty_op.operand);
    // Note: operand_ty may be signed, while ty is always unsigned!
    const operand_ty = self.typeOf(ty_op.operand);
    const result_ty = self.typeOfIndex(inst);
    const operand_info = self.arithmeticTypeInfo(operand_ty);

    var wip = try self.elementWise(result_ty, false);
    defer wip.deinit();

    for (wip.results, 0..) |*result_id, i| {
        const elem_id = try wip.elementAt(operand_ty, operand_id, i);

        // Opcode within the imported extended instruction set.
        const ext_inst: Word = switch (target.os.tag) {
            .opencl => switch (operand_info.class) {
                .float => 23, // fabs
                .integer, .strange_integer => switch (operand_info.signedness) {
                    .signed => 141, // s_abs
                    .unsigned => 201, // u_abs
                },
                .composite_integer => unreachable, // TODO
                .bool => unreachable,
            },
            .vulkan => switch (operand_info.class) {
                .float => 4, // FAbs
                // NOTE(review): SAbs is used regardless of signedness here;
                // confirm that unsigned operands cannot reach this prong.
                .integer, .strange_integer => 5, // SAbs
                .composite_integer => unreachable, // TODO
                .bool => unreachable,
            },
            else => unreachable,
        };
        const set_id = switch (target.os.tag) {
            .opencl => try self.spv.importInstructionSet(.opencl),
            .vulkan => try self.spv.importInstructionSet(.glsl),
            else => unreachable,
        };

        result_id.* = self.spv.allocId();
        try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
            .id_result_type = wip.ty_id,
            .id_result = result_id.*,
            .set = set_id,
            .instruction = .{ .inst = ext_inst },
            .id_ref_4 = &.{elem_id},
        });
    }
    return try wip.finalize();
}

View File

@ -114,8 +114,10 @@ sections: struct {
capabilities: Section = .{},
/// OpExtension instructions
extensions: Section = .{},
// OpExtInstImport instructions - skip for now.
// memory model defined by target, not required here.
/// OpExtInstImport
extended_instruction_set: Section = .{},
/// memory model defined by target
memory_model: Section = .{},
/// OpEntryPoint instructions - Handled by `self.entry_points`.
/// OpExecutionMode and OpExecutionModeId instructions.
execution_modes: Section = .{},
@ -172,6 +174,9 @@ globals: struct {
section: Section = .{},
} = .{},
/// The list of extended instruction sets that should be imported.
extended_instruction_set: std.AutoHashMapUnmanaged(ExtendedInstructionSet, IdRef) = .{},
pub fn init(gpa: Allocator) Module {
return .{
.gpa = gpa,
@ -182,6 +187,8 @@ pub fn init(gpa: Allocator) Module {
pub fn deinit(self: *Module) void {
self.sections.capabilities.deinit(self.gpa);
self.sections.extensions.deinit(self.gpa);
self.sections.extended_instruction_set.deinit(self.gpa);
self.sections.memory_model.deinit(self.gpa);
self.sections.execution_modes.deinit(self.gpa);
self.sections.debug_strings.deinit(self.gpa);
self.sections.debug_names.deinit(self.gpa);
@ -200,6 +207,8 @@ pub fn deinit(self: *Module) void {
self.globals.globals.deinit(self.gpa);
self.globals.section.deinit(self.gpa);
self.extended_instruction_set.deinit(self.gpa);
self.* = undefined;
}
@ -448,6 +457,8 @@ pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
&header,
self.sections.capabilities.toWords(),
self.sections.extensions.toWords(),
self.sections.extended_instruction_set.toWords(),
self.sections.memory_model.toWords(),
entry_points.toWords(),
self.sections.execution_modes.toWords(),
source.toWords(),
@ -482,6 +493,29 @@ pub fn addFunction(self: *Module, decl_index: Decl.Index, func: Fn) !void {
try self.declareDeclDeps(decl_index, func.decl_deps.keys());
}
/// Extended instruction sets that `importInstructionSet` can import.
pub const ExtendedInstructionSet = enum {
// Imported under the name "GLSL.std.450".
glsl,
// Imported under the name "OpenCL.std".
opencl,
};
/// Imports or returns the existing id of an extended instruction set.
///
/// The first request for a given `set` allocates a fresh result-id, emits the
/// matching OpExtInstImport into `sections.extended_instruction_set` and caches
/// the id in `extended_instruction_set`. Later requests return the cached id
/// without emitting anything.
///
/// Returns an error if growing the cache or emitting the instruction fails; in
/// that case no entry for `set` is left behind.
pub fn importInstructionSet(self: *Module, set: ExtendedInstructionSet) !IdRef {
    const gop = try self.extended_instruction_set.getOrPut(self.gpa, set);
    if (gop.found_existing) return gop.value_ptr.*;
    // The freshly inserted slot has no valid value yet. If anything below
    // fails, drop it again so a later call cannot return a garbage id.
    errdefer _ = self.extended_instruction_set.remove(set);

    const result_id = self.allocId();
    gop.value_ptr.* = result_id;
    try self.sections.extended_instruction_set.emit(self.gpa, .OpExtInstImport, .{
        .id_result = result_id,
        .name = switch (set) {
            .glsl => "GLSL.std.450",
            .opencl => "OpenCL.std",
        },
    });
    return result_id;
}
/// Fetch the result-id of an OpString instruction that encodes the path of the source
/// file of the decl. This function may also emit an OpSource with source-level information regarding
/// the decl.

View File

@ -246,7 +246,7 @@ fn writeCapabilities(spv: *SpvModule, target: std.Target) !void {
const gpa = spv.gpa;
// TODO: Integrate with a hypothetical feature system
const caps: []const spec.Capability = switch (target.os.tag) {
.opencl => &.{ .Kernel, .Addresses, .Int8, .Int16, .Int64, .Float64, .Float16, .GenericPointer },
.opencl => &.{ .Kernel, .Addresses, .Int8, .Int16, .Int64, .Float64, .Float16, .Vector16, .GenericPointer },
.glsl450 => &.{.Shader},
.vulkan => &.{ .Shader, .VariablePointersStorageBuffer, .Int8, .Int16, .Int64, .Float64, .Float16 },
else => unreachable, // TODO
@ -279,8 +279,7 @@ fn writeMemoryModel(spv: *SpvModule, target: std.Target) !void {
else => unreachable,
};
// TODO: Put this in a proper section.
try spv.sections.extensions.emit(gpa, .OpMemoryModel, .{
try spv.sections.memory_model.emit(gpa, .OpMemoryModel, .{
.addressing_model = addressing_model,
.memory_model = memory_model,
});