diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 62f945dd81..42a11120b5 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -1970,6 +1970,9 @@ const DeclGen = struct { .shl => try self.airShift(inst, .OpShiftLeftLogical), + .min => try self.airMinMax(inst, .lt), + .max => try self.airMinMax(inst, .gt), + .bitcast => try self.airBitCast(inst), .intcast, .trunc => try self.airIntCast(inst), .int_from_ptr => try self.airIntFromPtr(inst), @@ -2103,6 +2106,54 @@ const DeclGen = struct { return result_id; } + fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const lhs_id = try self.resolve(bin_op.lhs); + const rhs_id = try self.resolve(bin_op.rhs); + const result_ty = self.typeOfIndex(inst); + const result_ty_ref = try self.resolveType(result_ty, .direct); + + const info = try self.arithmeticTypeInfo(result_ty); + // TODO: Use fmin for OpenCL + const cmp_id = try self.cmp(op, result_ty, lhs_id, rhs_id); + const selection_id = switch (info.class) { + .float => blk: { + // cmp uses OpFOrd. When we have 0 [<>] nan this returns false, + // but we want it to pick lhs. Therefore we also have to check if + // rhs is nan. We don't need to care about the result when both + // are nan. + const rhs_is_nan_id = self.spv.allocId(); + const bool_ty_ref = try self.resolveType(Type.bool, .direct); + try self.func.body.emit(self.spv.gpa, .OpIsNan, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = rhs_is_nan_id, + .x = rhs_id, + }); + const float_cmp_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = float_cmp_id, + .operand_1 = cmp_id, + .operand_2 = rhs_is_nan_id, + }); + break :blk float_cmp_id; + }, + else => cmp_id, + }; + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = self.typeId(result_ty_ref), + .id_result = result_id, + .condition = selection_id, + .object_1 = lhs_id, + .object_2 = rhs_id, + }); + return result_id; + } + fn maskStrangeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, bits: u16) !IdRef { const mask_value = if (bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(bits))) - 1; const result_id = self.spv.allocId(); diff --git a/test/behavior/maximum_minimum.zig b/test/behavior/maximum_minimum.zig index c76dd7573d..0b9ca8c636 100644 --- a/test/behavior/maximum_minimum.zig +++ b/test/behavior/maximum_minimum.zig @@ -9,14 +9,16 @@ test "@max" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { var x: i32 = 10; var y: f32 = 0.68; + var nan: f32 = std.math.nan(f32); try expect(@as(i32, 10) == @max(@as(i32, -3), x)); try expect(@as(f32, 3.2) == @max(@as(f32, 3.2), y)); + try expect(y == @max(nan, y)); + try expect(y == @max(y, nan)); } }; try S.doTheTest(); @@ -58,14 +60,16 @@ test "@min" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { var x: i32 = 10; var y: f32 = 0.68; + var nan: f32 = std.math.nan(f32); try expect(@as(i32, -3) == @min(@as(i32, -3), x)); try expect(@as(f32, 0.68) == @min(@as(f32, 3.2), y)); + try expect(y == @min(nan, y)); + try expect(y == @min(y, nan)); } }; try S.doTheTest();