mirror of
https://github.com/ziglang/zig.git
synced 2026-01-20 22:35:24 +00:00
spirv: vectorize max, min
This commit is contained in:
parent
9f0227a326
commit
9641d2ebdb
@ -2391,45 +2391,51 @@ const DeclGen = struct {
|
||||
}
|
||||
|
||||
fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
|
||||
const result_ty_ref = try self.resolveType(result_ty, .direct);
|
||||
const info = self.arithmeticTypeInfo(result_ty);
|
||||
|
||||
// TODO: Use fmin for OpenCL
|
||||
const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
|
||||
const selection_id = switch (info.class) {
|
||||
.float => blk: {
|
||||
// cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
|
||||
// but we want it to pick lhs. Therefore we also have to check if
|
||||
// rhs is nan. We don't need to care about the result when both
|
||||
// are nan.
|
||||
const rhs_is_nan_id = self.spv.allocId();
|
||||
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
|
||||
try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
|
||||
.id_result_type = self.typeId(bool_ty_ref),
|
||||
.id_result = rhs_is_nan_id,
|
||||
.x = rhs_id,
|
||||
});
|
||||
const float_cmp_id = self.spv.allocId();
|
||||
try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
|
||||
.id_result_type = self.typeId(bool_ty_ref),
|
||||
.id_result = float_cmp_id,
|
||||
.operand_1 = cmp_id,
|
||||
.operand_2 = rhs_is_nan_id,
|
||||
});
|
||||
break :blk float_cmp_id;
|
||||
},
|
||||
else => cmp_id,
|
||||
};
|
||||
var wip = try self.elementWise(result_ty);
|
||||
defer wip.deinit();
|
||||
for (wip.results, 0..) |*result_id, i| {
|
||||
const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
|
||||
const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
|
||||
|
||||
const result_id = self.spv.allocId();
|
||||
try self.func.body.emit(self.spv.gpa, .OpSelect, .{
|
||||
.id_result_type = self.typeId(result_ty_ref),
|
||||
.id_result = result_id,
|
||||
.condition = selection_id,
|
||||
.object_1 = lhs_id,
|
||||
.object_2 = rhs_id,
|
||||
});
|
||||
return result_id;
|
||||
// TODO: Use fmin for OpenCL
|
||||
const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id);
|
||||
const selection_id = switch (info.class) {
|
||||
.float => blk: {
|
||||
// cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
|
||||
// but we want it to pick lhs. Therefore we also have to check if
|
||||
// rhs is nan. We don't need to care about the result when both
|
||||
// are nan.
|
||||
const rhs_is_nan_id = self.spv.allocId();
|
||||
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
|
||||
try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
|
||||
.id_result_type = self.typeId(bool_ty_ref),
|
||||
.id_result = rhs_is_nan_id,
|
||||
.x = rhs_elem_id,
|
||||
});
|
||||
const float_cmp_id = self.spv.allocId();
|
||||
try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
|
||||
.id_result_type = self.typeId(bool_ty_ref),
|
||||
.id_result = float_cmp_id,
|
||||
.operand_1 = cmp_id,
|
||||
.operand_2 = rhs_is_nan_id,
|
||||
});
|
||||
break :blk float_cmp_id;
|
||||
},
|
||||
else => cmp_id,
|
||||
};
|
||||
|
||||
result_id.* = self.spv.allocId();
|
||||
try self.func.body.emit(self.spv.gpa, .OpSelect, .{
|
||||
.id_result_type = wip.scalar_ty_id,
|
||||
.id_result = result_id.*,
|
||||
.condition = selection_id,
|
||||
.object_1 = lhs_elem_id,
|
||||
.object_2 = rhs_elem_id,
|
||||
});
|
||||
}
|
||||
return wip.finalize();
|
||||
}
|
||||
|
||||
/// This function normalizes values to a canonical representation
|
||||
@ -3107,20 +3113,15 @@ const DeclGen = struct {
|
||||
return result_id;
|
||||
},
|
||||
.Vector => {
|
||||
const child_ty = ty.childType(mod);
|
||||
const vector_len = ty.vectorLen(mod);
|
||||
|
||||
const constituents = try self.gpa.alloc(IdRef, vector_len);
|
||||
defer self.gpa.free(constituents);
|
||||
|
||||
for (constituents, 0..) |*constituent, i| {
|
||||
const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
|
||||
const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
|
||||
const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
|
||||
constituent.* = try self.convertToIndirect(Type.bool, result_id);
|
||||
var wip = try self.elementWise(result_ty);
|
||||
defer wip.deinit();
|
||||
const scalar_ty = ty.scalarType(mod);
|
||||
for (wip.results, 0..) |*result_id, i| {
|
||||
const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
|
||||
const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
|
||||
result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id);
|
||||
}
|
||||
|
||||
return try self.constructArray(result_ty, constituents);
|
||||
return wip.finalize();
|
||||
},
|
||||
else => unreachable,
|
||||
};
|
||||
|
||||
@ -31,7 +31,6 @@ test "@max on vectors" {
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64 and
|
||||
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;
|
||||
|
||||
@ -86,7 +85,6 @@ test "@min for vectors" {
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64 and
|
||||
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;
|
||||
|
||||
@ -199,7 +197,6 @@ test "@min/@max notices vector bounds" {
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
|
||||
|
||||
var x: @Vector(2, u16) = .{ 140, 40 };
|
||||
@ -253,7 +250,6 @@ test "@min/@max notices bounds from vector types" {
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
|
||||
|
||||
var x: @Vector(2, u16) = .{ 30, 67 };
|
||||
@ -295,7 +291,6 @@ test "@min/@max notices bounds from vector types when element of comptime-known
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64 and
|
||||
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx)) return error.SkipZigTest;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user