diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 286b45f973..33e068032c 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2187,6 +2187,7 @@ const DeclGen = struct { .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan), .shl_with_overflow => try self.airShlOverflow(inst), + .reduce, .reduce_optimized => try self.airReduce(inst), .shuffle => try self.airShuffle(inst), .ptr_add => try self.airPtrAdd(inst), @@ -2388,9 +2389,14 @@ const DeclGen = struct { const lhs_id = try self.resolve(bin_op.lhs); const rhs_id = try self.resolve(bin_op.rhs); const result_ty = self.typeOfIndex(inst); - const result_ty_ref = try self.resolveType(result_ty, .direct); + return try self.minMax(result_ty, op, lhs_id, rhs_id); + } + + fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef { + const result_ty_ref = try self.resolveType(result_ty, .direct); const info = try self.arithmeticTypeInfo(result_ty); + // TODO: Use fmin for OpenCL const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id); const selection_id = switch (info.class) { @@ -2758,6 +2764,73 @@ const DeclGen = struct { ); } + fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + const mod = self.module; + const reduce = self.air.instructions.items(.data)[@intFromEnum(inst)].reduce; + const operand = try self.resolve(reduce.operand); + const operand_ty = self.typeOf(reduce.operand); + const scalar_ty = operand_ty.scalarType(mod); + const scalar_ty_ref = try self.resolveType(scalar_ty, .direct); + const scalar_ty_id = self.typeId(scalar_ty_ref); + + const info = try self.arithmeticTypeInfo(operand_ty); + + var result_id = try self.extractField(scalar_ty, operand, 0); + const len = operand_ty.vectorLen(mod); + + switch (reduce.operation) { + .Min, .Max => |op| { + const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt; + for (1..len) |i| { + const lhs = result_id; + const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs); + } + + return result_id; + }, + else => {}, + } + + const opcode: Opcode = switch (info.class) { + .bool => switch (reduce.operation) { + .And => .OpLogicalAnd, + .Or => .OpLogicalOr, + .Xor => .OpLogicalNotEqual, + else => unreachable, + }, + .strange_integer, .integer => switch (reduce.operation) { + .And => .OpBitwiseAnd, + .Or => .OpBitwiseOr, + .Xor => .OpBitwiseXor, + .Add => .OpIAdd, + .Mul => .OpIMul, + else => unreachable, + }, + .float => switch (reduce.operation) { + .Add => .OpFAdd, + .Mul => .OpFMul, + else => unreachable, + }, + .composite_integer => unreachable, // TODO + }; + + for (1..len) |i| { + const lhs = result_id; + const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + result_id = self.spv.allocId(); + + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, scalar_ty_id); + self.func.body.writeOperand(spec.IdResult, result_id); + self.func.body.writeOperand(spec.IdResultType, lhs); + self.func.body.writeOperand(spec.IdResultType, rhs); + } + + return result_id; + } + fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const mod = self.module; if (self.liveness.isUnused(inst)) return null; diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig index d02f0b6515..f87f7b722d 100644 --- a/test/behavior/vector.zig +++ b/test/behavior/vector.zig @@ -1231,7 +1231,6 @@ test "byte vector initialized in inline function" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (comptime builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .x86_64 and builtin.cpu.features.isEnabled(@intFromEnum(std.Target.x86.Feature.avx512f))) @@ -1301,7 +1300,6 @@ test "@intCast to u0" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var zeros = @Vector(2, u32){ 0, 0 }; _ = &zeros;