From 9641d2ebdb74926a56ff3b916082534052dc637f Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 21 Jan 2024 20:12:25 +0100 Subject: [PATCH] spirv: vectorize max, min --- src/codegen/spirv.zig | 99 ++++++++++++++++--------------- test/behavior/maximum_minimum.zig | 5 -- 2 files changed, 50 insertions(+), 54 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index eda8d88cdb..be3c9957e1 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2391,45 +2391,51 @@ const DeclGen = struct { } fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef { - const result_ty_ref = try self.resolveType(result_ty, .direct); const info = self.arithmeticTypeInfo(result_ty); - // TODO: Use fmin for OpenCL - const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id); - const selection_id = switch (info.class) { - .float => blk: { - // cmp uses OpFOrd. When we have 0 [<>] nan this returns false, - // but we want it to pick lhs. Therefore we also have to check if - // rhs is nan. We don't need to care about the result when both - // are nan. - const rhs_is_nan_id = self.spv.allocId(); - const bool_ty_ref = try self.resolveType(Type.bool, .direct); - try self.func.body.emit(self.spv.gpa, .OpIsNan, .{ - .id_result_type = self.typeId(bool_ty_ref), - .id_result = rhs_is_nan_id, - .x = rhs_id, - }); - const float_cmp_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ - .id_result_type = self.typeId(bool_ty_ref), - .id_result = float_cmp_id, - .operand_1 = cmp_id, - .operand_2 = rhs_is_nan_id, - }); - break :blk float_cmp_id; - }, - else => cmp_id, - }; + var wip = try self.elementWise(result_ty); + defer wip.deinit(); + for (wip.results, 0..) |*result_id, i| { + const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); + const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = self.typeId(result_ty_ref), - .id_result = result_id, - .condition = selection_id, - .object_1 = lhs_id, - .object_2 = rhs_id, - }); - return result_id; + // TODO: Use fmin for OpenCL + const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id); + const selection_id = switch (info.class) { + .float => blk: { + // cmp uses OpFOrd. When we have 0 [<>] nan this returns false, + // but we want it to pick lhs. Therefore we also have to check if + // rhs is nan. We don't need to care about the result when both + // are nan. + const rhs_is_nan_id = self.spv.allocId(); + const bool_ty_ref = try self.resolveType(Type.bool, .direct); + try self.func.body.emit(self.spv.gpa, .OpIsNan, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = rhs_is_nan_id, + .x = rhs_elem_id, + }); + const float_cmp_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = float_cmp_id, + .operand_1 = cmp_id, + .operand_2 = rhs_is_nan_id, + }); + break :blk float_cmp_id; + }, + else => cmp_id, + }; + + result_id.* = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = wip.scalar_ty_id, + .id_result = result_id.*, + .condition = selection_id, + .object_1 = lhs_elem_id, + .object_2 = rhs_elem_id, + }); + } + return wip.finalize(); } /// This function normalizes values to a canonical representation @@ -3107,20 +3113,15 @@ const DeclGen = struct { return result_id; }, .Vector => { - const child_ty = ty.childType(mod); - const vector_len = ty.vectorLen(mod); - - const constituents = try self.gpa.alloc(IdRef, vector_len); - defer self.gpa.free(constituents); - - for (constituents, 0..) |*constituent, i| { - const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i)); - const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i)); - const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id); - constituent.* = try self.convertToIndirect(Type.bool, result_id); + var wip = try self.elementWise(result_ty); + defer wip.deinit(); + const scalar_ty = ty.scalarType(mod); + for (wip.results, 0..) |*result_id, i| { + const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); + const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); + result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id); } - - return try self.constructArray(result_ty, constituents); + return wip.finalize(); }, else => unreachable, }; diff --git a/test/behavior/maximum_minimum.zig b/test/behavior/maximum_minimum.zig index f7cb1ee513..a6a2e3b8e8 100644 --- a/test/behavior/maximum_minimum.zig +++ b/test/behavior/maximum_minimum.zig @@ -31,7 +31,6 @@ test "@max on vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; @@ -86,7 +85,6 @@ test "@min for vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; @@ -199,7 +197,6 @@ test "@min/@max notices vector bounds" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; var x: @Vector(2, u16) = .{ 140, 40 }; @@ -253,7 +250,6 @@ test "@min/@max notices bounds from vector types" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; var x: @Vector(2, u16) = .{ 30, 67 }; @@ -295,7 +291,6 @@ test "@min/@max notices bounds from vector types when element of comptime-known if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx)) return error.SkipZigTest;