diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 1cf67357f0..215a9421f1 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -3,6 +3,7 @@ const Allocator = std.mem.Allocator; const Target = std.Target; const log = std.log.scoped(.codegen); const assert = std.debug.assert; +const Signedness = std.builtin.Signedness; const Module = @import("../Module.zig"); const Decl = Module.Decl; @@ -423,6 +424,17 @@ const DeclGen = struct { return self.fail("TODO (SPIR-V): " ++ format, args); } + /// This imports the "default" extended instruction set for the target + /// For OpenCL, OpenCL.std.100. For Vulkan, GLSL.std.450. + fn importExtendedSet(self: *DeclGen) !IdResult { + const target = self.getTarget(); + return switch (target.os.tag) { + .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), + .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), + else => unreachable, + }; + } + /// Fetch the result-id for a previously generated instruction or constant. fn resolve(self: *DeclGen, inst: Air.Inst.Ref) !IdRef { const mod = self.module; @@ -631,6 +643,19 @@ const DeclGen = struct { const mod = self.module; const target = self.getTarget(); if (ty.zigTypeTag(mod) != .Vector) return false; + + // TODO: This check must be expanded for types that can be represented + // as integers (enums / packed structs?) and types that are represented + // by multiple SPIR-V values. + const scalar_ty = ty.scalarType(mod); + switch (scalar_ty.zigTypeTag(mod)) { + .Bool, + .Int, + .Float, + => {}, + else => return false, + } + const elem_ty = ty.childType(mod); const len = ty.vectorLen(mod); @@ -723,9 +748,13 @@ const DeclGen = struct { // Use backing bits so that negatives are sign extended const backing_bits = self.backingIntBits(int_info.bits).?; // Assertion failure means big int - const bits: u64 = switch (int_info.signedness) { - // Intcast needed to silence compile errors for when the wrong path is compiled. - // Lazy fix. + const signedness: Signedness = switch (@typeInfo(@TypeOf(value))) { + .Int => |int| int.signedness, + .ComptimeInt => if (value < 0) .signed else .unsigned, + else => unreachable, + }; + + const bits: u64 = switch (signedness) { .signed => @bitCast(@as(i64, @intCast(value))), .unsigned => @as(u64, @intCast(value)), }; @@ -1392,6 +1421,19 @@ const DeclGen = struct { return ty_id; } + fn zigScalarOrVectorTypeLike(self: *DeclGen, new_ty: Type, base_ty: Type) !Type { + const mod = self.module; + const new_scalar_ty = new_ty.scalarType(mod); + if (!base_ty.isVector(mod)) { + return new_scalar_ty; + } + + return try mod.vectorType(.{ + .len = base_ty.vectorLen(mod), + .child = new_scalar_ty.toIntern(), + }); + } + /// Generate a union type. Union types are always generated with the /// most aligned field active. If the tag alignment is greater /// than that of the payload, a regular union (non-packed, with both tag and @@ -1928,77 +1970,897 @@ const DeclGen = struct { return union_layout; } - /// This structure is used as helper for element-wise operations. It is intended - /// to be used with vectors, fake vectors (arrays) and single elements. - const WipElementWise = struct { - dg: *DeclGen, - result_ty: Type, + /// This structure represents a "temporary" value: Something we are currently + /// operating on. It typically lives no longer than the function that + /// implements a particular AIR operation. 
These are used to more easily
+    /// implement vectorizable operations (see Vectorization and the build*
+    /// functions), and typically are only used for vectors of primitive types.
+    const Temporary = struct {
+        /// The type of the temporary. This is here mainly
+        /// for easier bookkeeping. Because we will never really
+        /// store Temporaries, they only cost some stack space, and
+        /// therefore no real storage is wasted.
         ty: Type,
-        /// Always in direct representation.
-        ty_id: IdRef,
-        /// True if the input is an array type.
-        is_array: bool,
-        /// The element-wise operation should fill these results before calling finalize().
-        /// These should all be in **direct** representation! `finalize()` will convert
-        /// them to indirect if required.
-        results: []IdRef,
+        /// The value that this temporary holds. This is not necessarily
+        /// a value that is actually usable, or a single value: it is virtual
+        /// until materialize() is called, at which point it is turned into
+        /// the usual SPIR-V representation of `self.ty`.
+        value: Temporary.Value,

-        fn deinit(wip: *WipElementWise) void {
-            wip.dg.gpa.free(wip.results);
+        const Value = union(enum) {
+            singleton: IdResult,
+            exploded_vector: IdRange,
+        };
+
+        fn init(ty: Type, singleton: IdResult) Temporary {
+            return .{ .ty = ty, .value = .{ .singleton = singleton } };
         }

-        /// Utility function to extract the element at a particular index in an
-        /// input array. This type is expected to be a fake vector (array) if `wip.is_array`, and
-        /// a vector or scalar otherwise.
-        fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef {
-            const mod = wip.dg.module;
-            if (wip.is_array) {
-                assert(ty.isVector(mod));
-                return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index));
-            } else {
-                assert(index == 0);
-                return value;
+        fn materialize(self: Temporary, dg: *DeclGen) !IdResult {
+            const mod = dg.module;
+            switch (self.value) {
+                .singleton => |id| return id,
+                .exploded_vector => |range| {
+                    assert(self.ty.isVector(mod));
+                    assert(self.ty.vectorLen(mod) == range.len);
+                    const constituents = try dg.gpa.alloc(IdRef, range.len);
+                    defer dg.gpa.free(constituents);
+                    for (constituents, 0..range.len) |*id, i| {
+                        id.* = range.at(i);
+                    }
+                    return dg.constructVector(self.ty, constituents);
+                },
             }
         }

-        /// Turns the results of this WipElementWise into a result. This can be
-        /// vectors, fake vectors (arrays) and single elements, depending on `result_ty`.
-        /// After calling this function, this WIP is no longer usable.
-        /// Results is in `direct` representation.
-        fn finalize(wip: *WipElementWise) !IdRef {
-            if (wip.is_array) {
-                return try wip.dg.constructVector(wip.result_ty, wip.results);
-            } else {
-                return wip.results[0];
-            }
+        fn vectorization(self: Temporary, dg: *DeclGen) Vectorization {
+            return Vectorization.fromType(self.ty, dg);
         }

-        /// Allocate a result id at a particular index, and return it.
-        fn allocId(wip: *WipElementWise, index: usize) IdRef {
-            assert(wip.is_array or index == 0);
-            wip.results[index] = wip.dg.spv.allocId();
-            return wip.results[index];
+        fn pun(self: Temporary, new_ty: Type) Temporary {
+            return .{
+                .ty = new_ty,
+                .value = self.value,
+            };
+        }
+
+        /// 'Explode' a temporary into separate elements. This turns a vector
+        /// into a bag of elements.
+        fn explode(self: Temporary, dg: *DeclGen) !IdRange {
+            const mod = dg.module;
+
+            // If the value is a scalar, then this is a no-op.
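+            // (There is nothing to explode in that case: a scalar becomes a
+            // 1-element range, and an already-exploded vector is passed through.)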
+ if (!self.ty.isVector(mod)) { + return switch (self.value) { + .singleton => |id| IdRange{ .base = @intFromEnum(id), .len = 1 }, + .exploded_vector => |range| range, + }; + } + + const ty_id = try dg.resolveType(self.ty.scalarType(mod), .direct); + const n = self.ty.vectorLen(mod); + const results = dg.spv.allocIds(n); + + const id = switch (self.value) { + .singleton => |id| id, + .exploded_vector => |range| return range, + }; + + for (0..n) |i| { + const indexes = [_]u32{@intCast(i)}; + try dg.func.body.emit(dg.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = ty_id, + .id_result = results.at(i), + .composite = id, + .indexes = &indexes, + }); + } + + return results; } }; - /// Create a new element-wise operation. - fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise { - const mod = self.module; - const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise); - const num_results = if (is_array) result_ty.vectorLen(mod) else 1; - const results = try self.gpa.alloc(IdRef, num_results); - @memset(results, undefined); - - const ty = if (is_array) result_ty.scalarType(mod) else result_ty; - const ty_id = try self.resolveType(ty, .direct); - + /// Initialize a `Temporary` from an AIR value. + fn temporary(self: *DeclGen, inst: Air.Inst.Ref) !Temporary { return .{ - .dg = self, - .result_ty = result_ty, - .ty = ty, - .ty_id = ty_id, - .is_array = is_array, - .results = results, + .ty = self.typeOf(inst), + .value = .{ .singleton = try self.resolve(inst) }, + }; + } + + /// This union describes how a particular operation should be vectorized. + /// That depends on the operation and number of components of the inputs. + const Vectorization = union(enum) { + /// This is an operation between scalars. + scalar, + /// This is an operation between SPIR-V vectors. + /// Value is number of components. + spv_vectorized: u32, + /// This operation is unrolled into separate operations. + /// Inputs may still be SPIR-V vectors, for example, + /// when the operation can't be vectorized in SPIR-V. + /// Value is number of components. + unrolled: u32, + + /// Derive a vectorization from a particular type. This usually + /// only checks the size, but the source-of-truth is implemented + /// by `isSpvVector()`. 
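+        /// For example, a scalar type always maps to `.scalar`, while a vector
+        /// type maps to `.spv_vectorized` if it can be represented by a native
+        /// SPIR-V vector, and to `.unrolled` otherwise.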
fn fromType(ty: Type, dg: *DeclGen) Vectorization {
+            const mod = dg.module;
+            if (!ty.isVector(mod)) {
+                return .scalar;
+            } else if (dg.isSpvVector(ty)) {
+                return .{ .spv_vectorized = ty.vectorLen(mod) };
+            } else {
+                return .{ .unrolled = ty.vectorLen(mod) };
+            }
+        }
+
+        /// Given two vectorization methods, compute a "unification": a fallback
+        /// that works for both, according to the following rules:
+        /// - Scalars may broadcast
+        /// - SPIR-V vectorized operations may unroll
+        /// - Prefer scalar > SPIR-V vectorized > unrolled
+        fn unify(a: Vectorization, b: Vectorization) Vectorization {
+            if (a == .scalar and b == .scalar) {
+                return .scalar;
+            } else if (a == .spv_vectorized and b == .spv_vectorized) {
+                assert(a.components() == b.components());
+                return .{ .spv_vectorized = a.components() };
+            } else if (a == .unrolled or b == .unrolled) {
+                if (a == .unrolled and b == .unrolled) {
+                    assert(a.components() == b.components());
+                    return .{ .unrolled = a.components() };
+                } else if (a == .unrolled) {
+                    return .{ .unrolled = a.components() };
+                } else if (b == .unrolled) {
+                    return .{ .unrolled = b.components() };
+                } else {
+                    unreachable;
+                }
+            } else {
+                if (a == .spv_vectorized) {
+                    return .{ .spv_vectorized = a.components() };
+                } else if (b == .spv_vectorized) {
+                    return .{ .spv_vectorized = b.components() };
+                } else {
+                    unreachable;
+                }
+            }
+        }
+
+        /// Force this vectorization to be unrolled, if it's
+        /// an operation involving vectors.
+        fn unroll(self: Vectorization) Vectorization {
+            return switch (self) {
+                .scalar, .unrolled => self,
+                .spv_vectorized => |n| .{ .unrolled = n },
+            };
+        }
+
+        /// Query the number of components that inputs of this operation have.
+        /// Note: for broadcasting scalars, this returns the number of elements
+        /// that the broadcasted vector would have.
+        fn components(self: Vectorization) u32 {
+            return switch (self) {
+                .scalar => 1,
+                .spv_vectorized => |n| n,
+                .unrolled => |n| n,
+            };
+        }
+
+        /// Query the number of operations that this vectorization requires.
+        /// This is basically the number of components, except that SPIR-V vectorized
+        /// operations only need a single SPIR-V instruction.
+        fn operations(self: Vectorization) u32 {
+            return switch (self) {
+                .scalar, .spv_vectorized => 1,
+                .unrolled => |n| n,
+            };
+        }
+
+        /// Turns `ty` into the result-type of an individual vector operation.
+        /// `ty` may be a scalar or vector; it doesn't matter.
+        fn operationType(self: Vectorization, dg: *DeclGen, ty: Type) !Type {
+            const mod = dg.module;
+            const scalar_ty = ty.scalarType(mod);
+            return switch (self) {
+                .scalar, .unrolled => scalar_ty,
+                .spv_vectorized => |n| try mod.vectorType(.{
+                    .len = n,
+                    .child = scalar_ty.toIntern(),
+                }),
+            };
+        }
+
+        /// Turns `ty` into the result-type of the entire operation.
+        /// `ty` may be a scalar or vector; it doesn't matter.
+        fn resultType(self: Vectorization, dg: *DeclGen, ty: Type) !Type {
+            const mod = dg.module;
+            const scalar_ty = ty.scalarType(mod);
+            return switch (self) {
+                .scalar => scalar_ty,
+                .unrolled, .spv_vectorized => |n| try mod.vectorType(.{
+                    .len = n,
+                    .child = scalar_ty.toIntern(),
+                }),
+            };
+        }
+
+        /// Before a temporary can be used, some setup may need to be done. This function implements
+        /// this setup, and returns a prepared operand that holds the relevant information on how to access
+        /// elements of the input.
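+        /// For example, a scalar participating in an `.unrolled` operation is
+        /// broadcast, and a `.singleton` vector in an `.unrolled` operation is
+        /// exploded into its per-element IDs.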
+ fn prepare(self: Vectorization, dg: *DeclGen, tmp: Temporary) !PreparedOperand { + const mod = dg.module; + const is_vector = tmp.ty.isVector(mod); + const is_spv_vector = dg.isSpvVector(tmp.ty); + const value: PreparedOperand.Value = switch (tmp.value) { + .singleton => |id| switch (self) { + .scalar => blk: { + assert(!is_vector); + break :blk .{ .scalar = id }; + }, + .spv_vectorized => blk: { + if (is_vector) { + assert(is_spv_vector); + break :blk .{ .spv_vectorwise = id }; + } + + // Broadcast scalar into vector. + const vector_ty = try mod.vectorType(.{ + .len = self.components(), + .child = tmp.ty.toIntern(), + }); + + const vector = try dg.constructVectorSplat(vector_ty, id); + return .{ + .ty = vector_ty, + .value = .{ .spv_vectorwise = vector }, + }; + }, + .unrolled => blk: { + if (is_vector) { + break :blk .{ .vector_exploded = try tmp.explode(dg) }; + } else { + break :blk .{ .scalar_broadcast = id }; + } + }, + }, + .exploded_vector => |range| switch (self) { + .scalar => unreachable, + .spv_vectorized => |n| blk: { + // We can vectorize this operation, but we have an exploded vector. This can happen + // when a vectorizable operation succeeds a non-vectorizable operation. In this case, + // pack up the IDs into a SPIR-V vector. This path should not be able to be hit with + // a type that cannot do that. + assert(is_spv_vector); + assert(range.len == n); + const vec = try tmp.materialize(dg); + break :blk .{ .spv_vectorwise = vec }; + }, + .unrolled => |n| blk: { + assert(range.len == n); + break :blk .{ .vector_exploded = range }; + }, + }, + }; + + return .{ + .ty = tmp.ty, + .value = value, + }; + } + + /// Finalize the results of an operation back into a temporary. `results` is + /// a list of result-ids of the operation. + fn finalize(self: Vectorization, ty: Type, results: IdRange) Temporary { + assert(self.operations() == results.len); + const value: Temporary.Value = switch (self) { + .scalar, .spv_vectorized => blk: { + break :blk .{ .singleton = results.at(0) }; + }, + .unrolled => blk: { + break :blk .{ .exploded_vector = results }; + }, + }; + + return .{ .ty = ty, .value = value }; + } + + /// This struct represents an operand that has gone through some setup, and is + /// ready to be used as part of an operation. + const PreparedOperand = struct { + ty: Type, + value: PreparedOperand.Value, + + /// The types of value that a prepared operand can hold internally. Depends + /// on the operation and input value. + const Value = union(enum) { + /// A single scalar value that is used by a scalar operation. + scalar: IdResult, + /// A single scalar that is broadcasted in an unrolled operation. + scalar_broadcast: IdResult, + /// A SPIR-V vector that is used in SPIR-V vectorize operation. + spv_vectorwise: IdResult, + /// A vector represented by a consecutive list of IDs that is used in an unrolled operation. + vector_exploded: IdRange, + }; + + /// Query the value at a particular index of the operation. Note that + /// the index is *not* the component/lane, but the index of the *operation*. When + /// this operation is vectorized, the return value of this function is a SPIR-V vector. + /// See also `Vectorization.operations()`. 
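+            /// For example, an `.unrolled` operation over 4 components consists of
+            /// 4 operations, and `at(i)` returns the operand for the i'th of them,
+            /// while a `.spv_vectorized` operation is a single operation, for which
+            /// `at(0)` returns the entire vector.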
fn at(self: PreparedOperand, i: usize) IdResult {
+                switch (self.value) {
+                    .scalar => |id| {
+                        assert(i == 0);
+                        return id;
+                    },
+                    .scalar_broadcast => |id| {
+                        return id;
+                    },
+                    .spv_vectorwise => |id| {
+                        assert(i == 0);
+                        return id;
+                    },
+                    .vector_exploded => |range| {
+                        return range.at(i);
+                    },
+                }
+            }
+        };
+    };
+
+    /// A utility function to compute the vectorization style of
+    /// a list of values. These values may be any of the following:
+    /// - A `Vectorization` instance
+    /// - A Type, in which case the vectorization is computed via `Vectorization.fromType`.
+    /// - A Temporary, in which case the vectorization is computed via `Temporary.vectorization`.
+    fn vectorization(self: *DeclGen, args: anytype) Vectorization {
+        var v: Vectorization = undefined;
+        assert(args.len >= 1);
+        inline for (args, 0..) |arg, i| {
+            const iv: Vectorization = switch (@TypeOf(arg)) {
+                Vectorization => arg,
+                Type => Vectorization.fromType(arg, self),
+                Temporary => arg.vectorization(self),
+                else => @compileError("invalid type"),
+            };
+            if (i == 0) {
+                v = iv;
+            } else {
+                v = v.unify(iv);
+            }
+        }
+        return v;
+    }
+
+    /// This function builds an OpSConvert or OpUConvert instruction, depending on the
+    /// signedness of the types.
+    fn buildIntConvert(self: *DeclGen, dst_ty: Type, src: Temporary) !Temporary {
+        const mod = self.module;
+
+        const dst_ty_id = try self.resolveType(dst_ty.scalarType(mod), .direct);
+        const src_ty_id = try self.resolveType(src.ty.scalarType(mod), .direct);
+
+        const v = self.vectorization(.{ dst_ty, src });
+        const result_ty = try v.resultType(self, dst_ty);
+
+        // We can compare the type-IDs directly, because type-IDs are cached.
+        if (dst_ty_id == src_ty_id) {
+            // Nothing to do, type-pun to the right value.
+            // Note: the caller guarantees that the types fit (or will normalize afterwards),
+            // so we don't have to normalize here.
+            // Note: dst_ty may be a scalar type even if we expect a vector, so we have to
+            // convert to the right type here.
+            return src.pun(result_ty);
+        }
+
+        const ops = v.operations();
+        const results = self.spv.allocIds(ops);
+
+        const op_result_ty = try v.operationType(self, dst_ty);
+        const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
+
+        const opcode: Opcode = if (dst_ty.isSignedInt(mod)) .OpSConvert else .OpUConvert;
+
+        const op_src = try v.prepare(self, src);
+
+        for (0..ops) |i| {
+            try self.func.body.emitRaw(self.spv.gpa, opcode, 3);
+            self.func.body.writeOperand(spec.IdResultType, op_result_ty_id);
+            self.func.body.writeOperand(IdResult, results.at(i));
+            self.func.body.writeOperand(IdResult, op_src.at(i));
+        }
+
+        return v.finalize(result_ty, results);
+    }
+
+    fn buildFma(self: *DeclGen, a: Temporary, b: Temporary, c: Temporary) !Temporary {
+        const target = self.getTarget();
+
+        const v = self.vectorization(.{ a, b, c });
+        const ops = v.operations();
+        const results = self.spv.allocIds(ops);
+
+        const op_result_ty = try v.operationType(self, a.ty);
+        const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
+        const result_ty = try v.resultType(self, a.ty);
+
+        const op_a = try v.prepare(self, a);
+        const op_b = try v.prepare(self, b);
+        const op_c = try v.prepare(self, c);
+
+        const set = try self.importExtendedSet();
+
+        // TODO: Put these numbers in some definition
+        const instruction: u32 = switch (target.os.tag) {
+            .opencl => 26, // fma
+            // NOTE: Vulkan's FMA instruction does *NOT* produce the right values!
+            // Its precision guarantees do NOT match Zig's, and they do NOT match OpenCL's!
+ // it needs to be emulated! + .vulkan => unreachable, // TODO: See above + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = instruction }, + .id_ref_4 = &.{ op_a.at(i), op_b.at(i), op_c.at(i) }, + }); + } + + return v.finalize(result_ty, results); + } + + fn buildSelect(self: *DeclGen, condition: Temporary, lhs: Temporary, rhs: Temporary) !Temporary { + const mod = self.module; + + const v = self.vectorization(.{ condition, lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, lhs.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, lhs.ty); + + assert(condition.ty.scalarType(mod).zigTypeTag(mod) == .Bool); + + const cond = try v.prepare(self, condition); + const object_1 = try v.prepare(self, lhs); + const object_2 = try v.prepare(self, rhs); + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .condition = cond.at(i), + .object_1 = object_1.at(i), + .object_2 = object_2.at(i), + }); + } + + return v.finalize(result_ty, results); + } + + const CmpPredicate = enum { + l_eq, + l_ne, + i_ne, + i_eq, + s_lt, + s_gt, + s_le, + s_ge, + u_lt, + u_gt, + u_le, + u_ge, + f_oeq, + f_une, + f_olt, + f_ole, + f_ogt, + f_oge, + }; + + fn buildCmp(self: *DeclGen, pred: CmpPredicate, lhs: Temporary, rhs: Temporary) !Temporary { + const v = self.vectorization(.{ lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, Type.bool); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, Type.bool); + + const op_lhs = try v.prepare(self, lhs); + const op_rhs = try v.prepare(self, rhs); + + const opcode: Opcode = switch (pred) { + .l_eq => .OpLogicalEqual, + .l_ne => .OpLogicalNotEqual, + .i_eq => .OpIEqual, + .i_ne => .OpINotEqual, + .s_lt => .OpSLessThan, + .s_gt => .OpSGreaterThan, + .s_le => .OpSLessThanEqual, + .s_ge => .OpSGreaterThanEqual, + .u_lt => .OpULessThan, + .u_gt => .OpUGreaterThan, + .u_le => .OpULessThanEqual, + .u_ge => .OpUGreaterThanEqual, + .f_oeq => .OpFOrdEqual, + .f_une => .OpFUnordNotEqual, + .f_olt => .OpFOrdLessThan, + .f_ole => .OpFOrdLessThanEqual, + .f_ogt => .OpFOrdGreaterThan, + .f_oge => .OpFOrdGreaterThanEqual, + }; + + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_lhs.at(i)); + self.func.body.writeOperand(IdResult, op_rhs.at(i)); + } + + return v.finalize(result_ty, results); + } + + const UnaryOp = enum { + l_not, + bit_not, + i_neg, + f_neg, + i_abs, + f_abs, + clz, + ctz, + floor, + ceil, + trunc, + round, + sqrt, + sin, + cos, + tan, + exp, + exp2, + log, + log2, + log10, + }; + + fn buildUnary(self: *DeclGen, op: UnaryOp, operand: Temporary) !Temporary { + const target = self.getTarget(); + const v = blk: { + const v = self.vectorization(.{operand}); + break :blk switch (op) { + // TODO: These instructions don't seem to be working + // properly for LLVM-based backends on OpenCL for 8- and + // 16-component vectors. 
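+                // As a workaround, such cases are unrolled below.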
+ .i_abs => if (target.os.tag == .opencl and v.components() >= 8) v.unroll() else v, + else => v, + }; + }; + + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, operand.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, operand.ty); + + const op_operand = try v.prepare(self, operand); + + if (switch (op) { + .l_not => .OpLogicalNot, + .bit_not => .OpNot, + .i_neg => .OpSNegate, + .f_neg => .OpFNegate, + else => @as(?Opcode, null), + }) |opcode| { + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 3); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_operand.at(i)); + } + } else { + const set = try self.importExtendedSet(); + const extinst: u32 = switch (target.os.tag) { + .opencl => switch (op) { + .i_abs => 141, // s_abs + .f_abs => 23, // fabs + .clz => 151, // clz + .ctz => 152, // ctz + .floor => 25, // floor + .ceil => 12, // ceil + .trunc => 66, // trunc + .round => 55, // round + .sqrt => 61, // sqrt + .sin => 57, // sin + .cos => 14, // cos + .tan => 62, // tan + .exp => 19, // exp + .exp2 => 20, // exp2 + .log => 37, // log + .log2 => 38, // log2 + .log10 => 39, // log10 + else => unreachable, + }, + // Note: We'll need to check these for floating point accuracy + // Vulkan does not put tight requirements on these, for correction + // we might want to emulate them at some point. + .vulkan => switch (op) { + .i_abs => 5, // SAbs + .f_abs => 4, // FAbs + .clz => unreachable, // TODO + .ctz => unreachable, // TODO + .floor => 8, // Floor + .ceil => 9, // Ceil + .trunc => 3, // Trunc + .round => 1, // Round + .sqrt, + .sin, + .cos, + .tan, + .exp, + .exp2, + .log, + .log2, + .log10, + => unreachable, // TODO + else => unreachable, + }, + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = extinst }, + .id_ref_4 = &.{op_operand.at(i)}, + }); + } + } + + return v.finalize(result_ty, results); + } + + const BinaryOp = enum { + i_add, + f_add, + i_sub, + f_sub, + i_mul, + f_mul, + s_div, + u_div, + f_div, + s_rem, + f_rem, + s_mod, + u_mod, + f_mod, + srl, + sra, + sll, + bit_and, + bit_or, + bit_xor, + f_max, + s_max, + u_max, + f_min, + s_min, + u_min, + l_and, + l_or, + }; + + fn buildBinary(self: *DeclGen, op: BinaryOp, lhs: Temporary, rhs: Temporary) !Temporary { + const target = self.getTarget(); + + const v = self.vectorization(.{ lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, lhs.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, lhs.ty); + + const op_lhs = try v.prepare(self, lhs); + const op_rhs = try v.prepare(self, rhs); + + if (switch (op) { + .i_add => .OpIAdd, + .f_add => .OpFAdd, + .i_sub => .OpISub, + .f_sub => .OpFSub, + .i_mul => .OpIMul, + .f_mul => .OpFMul, + .s_div => .OpSDiv, + .u_div => .OpUDiv, + .f_div => .OpFDiv, + .s_rem => .OpSRem, + .f_rem => .OpFRem, + .s_mod => .OpSMod, + .u_mod => .OpUMod, + .f_mod => .OpFMod, + .srl => .OpShiftRightLogical, + .sra => .OpShiftRightArithmetic, + .sll => .OpShiftLeftLogical, + .bit_and => .OpBitwiseAnd, + .bit_or => .OpBitwiseOr, + .bit_xor 
=> .OpBitwiseXor, + .l_and => .OpLogicalAnd, + .l_or => .OpLogicalOr, + else => @as(?Opcode, null), + }) |opcode| { + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_lhs.at(i)); + self.func.body.writeOperand(IdResult, op_rhs.at(i)); + } + } else { + const set = try self.importExtendedSet(); + + // TODO: Put these numbers in some definition + const extinst: u32 = switch (target.os.tag) { + .opencl => switch (op) { + .f_max => 27, // fmax + .s_max => 156, // s_max + .u_max => 157, // u_max + .f_min => 28, // fmin + .s_min => 158, // s_min + .u_min => 159, // u_min + else => unreachable, + }, + .vulkan => switch (op) { + .f_max => 40, // FMax + .s_max => 42, // SMax + .u_max => 41, // UMax + .f_min => 37, // FMin + .s_min => 39, // SMin + .u_min => 38, // UMin + else => unreachable, + }, + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = extinst }, + .id_ref_4 = &.{ op_lhs.at(i), op_rhs.at(i) }, + }); + } + } + + return v.finalize(result_ty, results); + } + + /// This function builds an extended multiplication, either OpSMulExtended or OpUMulExtended on Vulkan, + /// or OpIMul and s_mul_hi or u_mul_hi on OpenCL. + fn buildWideMul( + self: *DeclGen, + op: enum { + s_mul_extended, + u_mul_extended, + }, + lhs: Temporary, + rhs: Temporary, + ) !struct { Temporary, Temporary } { + const mod = self.module; + const target = self.getTarget(); + const ip = &mod.intern_pool; + + const v = lhs.vectorization(self).unify(rhs.vectorization(self)); + const ops = v.operations(); + + const arith_op_ty = try v.operationType(self, lhs.ty); + const arith_op_ty_id = try self.resolveType(arith_op_ty, .direct); + + const lhs_op = try v.prepare(self, lhs); + const rhs_op = try v.prepare(self, rhs); + + const value_results = self.spv.allocIds(ops); + const overflow_results = self.spv.allocIds(ops); + + switch (target.os.tag) { + .opencl => { + // Currently, SPIRV-LLVM-Translator based backends cannot deal with OpSMulExtended and + // OpUMulExtended. For these we will use the OpenCL s_mul_hi to compute the high-order bits + // instead. + const set = try self.importExtendedSet(); + const overflow_inst: u32 = switch (op) { + .s_mul_extended => 160, // s_mul_hi + .u_mul_extended => 203, // u_mul_hi + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpIMul, .{ + .id_result_type = arith_op_ty_id, + .id_result = value_results.at(i), + .operand_1 = lhs_op.at(i), + .operand_2 = rhs_op.at(i), + }); + + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = arith_op_ty_id, + .id_result = overflow_results.at(i), + .set = set, + .instruction = .{ .inst = overflow_inst }, + .id_ref_4 = &.{ lhs_op.at(i), rhs_op.at(i) }, + }); + } + }, + .vulkan => { + const op_result_ty = blk: { + // Operations return a struct{T, T} + // where T is maybe vectorized. 
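+                    // (Member 0 of the result struct receives the low-order half of
+                    // the product, member 1 the high-order half.)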
+ const types = [2]InternPool.Index{ arith_op_ty.toIntern(), arith_op_ty.toIntern() }; + const values = [2]InternPool.Index{ .none, .none }; + const index = try ip.getAnonStructType(mod.gpa, .{ + .types = &types, + .values = &values, + .names = &.{}, + }); + break :blk Type.fromInterned(index); + }; + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + + const opcode: Opcode = switch (op) { + .s_mul_extended => .OpSMulExtended, + .u_mul_extended => .OpUMulExtended, + }; + + for (0..ops) |i| { + const op_result = self.spv.allocId(); + + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, op_result); + self.func.body.writeOperand(IdResult, lhs_op.at(i)); + self.func.body.writeOperand(IdResult, rhs_op.at(i)); + + // The above operation returns a struct. We might want to expand + // Temporary to deal with the fact that these are structs eventually, + // but for now, take the struct apart and return two separate vectors. + + try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = arith_op_ty_id, + .id_result = value_results.at(i), + .composite = op_result, + .indexes = &.{0}, + }); + + try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = arith_op_ty_id, + .id_result = overflow_results.at(i), + .composite = op_result, + .indexes = &.{1}, + }); + } + }, + else => unreachable, + } + + const result_ty = try v.resultType(self, lhs.ty); + return .{ + v.finalize(result_ty, value_results), + v.finalize(result_ty, overflow_results), }; } @@ -2237,59 +3099,42 @@ const DeclGen = struct { } } - fn intFromBool(self: *DeclGen, ty: Type, condition_id: IdRef) !IdRef { - const zero_id = try self.constInt(ty, 0, .direct); - const one_id = try self.constInt(ty, 1, .direct); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = result_id, - .condition = condition_id, - .object_1 = one_id, - .object_2 = zero_id, - }); - return result_id; + fn intFromBool(self: *DeclGen, value: Temporary) !Temporary { + return try self.intFromBool2(value, Type.u1); + } + + fn intFromBool2(self: *DeclGen, value: Temporary, result_ty: Type) !Temporary { + const zero_id = try self.constInt(result_ty, 0, .direct); + const one_id = try self.constInt(result_ty, 1, .direct); + + return try self.buildSelect( + value, + Temporary.init(result_ty, one_id), + Temporary.init(result_ty, zero_id), + ); } /// Convert representation from indirect (in memory) to direct (in 'register') /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct). fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - const scalar_ty = ty.scalarType(mod); - const is_spv_vector = self.isSpvVector(ty); - switch (scalar_ty.zigTypeTag(mod)) { + switch (ty.scalarType(mod).zigTypeTag(mod)) { .Bool => { - // TODO: We may want to use something like elementWise in this function. - // First we need to audit whether this would recursively call into itself. 
- if (!ty.isVector(mod) or is_spv_vector) { - const result_id = self.spv.allocId(); - const scalar_false_id = try self.constBool(false, .indirect); - const false_id = if (is_spv_vector) blk: { - const index = try mod.intern_pool.get(mod.gpa, .{ - .vector_type = .{ - .len = ty.vectorLen(mod), - .child = Type.u1.toIntern(), - }, - }); - const vec_ty = Type.fromInterned(index); - break :blk try self.constructVectorSplat(vec_ty, scalar_false_id); - } else scalar_false_id; + const false_id = try self.constBool(false, .indirect); + // The operation below requires inputs in direct representation, but the operand + // is actually in indirect representation. + // Cheekily swap out the type to the direct equivalent of the indirect type here, they have the + // same representation when converted to SPIR-V. + const operand_ty = try self.zigScalarOrVectorTypeLike(Type.u1, ty); + // Note: We can guarantee that these are the same ID due to the SPIR-V Module's `vector_types` cache! + assert(try self.resolveType(operand_ty, .direct) == try self.resolveType(ty, .indirect)); - try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = result_id, - .operand_1 = operand_id, - .operand_2 = false_id, - }); - return result_id; - } - - const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); - for (constituents, 0..) |*id, i| { - const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); - id.* = try self.convertToDirect(scalar_ty, element); - } - return try self.constructVector(ty, constituents); + const result = try self.buildCmp( + .i_ne, + Temporary.init(operand_ty, operand_id), + Temporary.init(Type.u1, false_id), + ); + return try result.materialize(self); }, else => return operand_id, } @@ -2299,55 +3144,10 @@ const DeclGen = struct { /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect). fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - const scalar_ty = ty.scalarType(mod); - const is_spv_vector = self.isSpvVector(ty); - switch (scalar_ty.zigTypeTag(mod)) { + switch (ty.scalarType(mod).zigTypeTag(mod)) { .Bool => { - const result_ty = if (is_spv_vector) blk: { - const index = try mod.intern_pool.get(mod.gpa, .{ - .vector_type = .{ - .len = ty.vectorLen(mod), - .child = Type.u1.toIntern(), - }, - }); - break :blk Type.fromInterned(index); - } else Type.u1; - - if (!ty.isVector(mod) or is_spv_vector) { - // TODO: We may want to use something like elementWise in this function. - // First we need to audit whether this would recursively call into itself. - // Also unify it with intFromBool - - const scalar_zero_id = try self.constInt(Type.u1, 0, .direct); - const scalar_one_id = try self.constInt(Type.u1, 1, .direct); - - const zero_id = if (is_spv_vector) - try self.constructVectorSplat(result_ty, scalar_zero_id) - else - scalar_zero_id; - - const one_id = if (is_spv_vector) - try self.constructVectorSplat(result_ty, scalar_one_id) - else - scalar_one_id; - - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = try self.resolveType(result_ty, .direct), - .id_result = result_id, - .condition = operand_id, - .object_1 = one_id, - .object_2 = zero_id, - }); - return result_id; - } - - const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); - for (constituents, 0..) 
|*id, i| { - const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); - id.* = try self.convertToIndirect(scalar_ty, element); - } - return try self.constructVector(result_ty, constituents); + const result = try self.intFromBool(Temporary.init(ty, operand_id)); + return try result.materialize(self); }, else => return operand_id, } @@ -2428,26 +3228,35 @@ const DeclGen = struct { const air_tags = self.air.instructions.items(.tag); const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) { // zig fmt: off - .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd), - .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub), - .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul), - + .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .f_add, .i_add, .i_add), + .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .f_sub, .i_sub, .i_sub), + .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .f_mul, .i_mul, .i_mul), + .sqrt => try self.airUnOpSimple(inst, .sqrt), + .sin => try self.airUnOpSimple(inst, .sin), + .cos => try self.airUnOpSimple(inst, .cos), + .tan => try self.airUnOpSimple(inst, .tan), + .exp => try self.airUnOpSimple(inst, .exp), + .exp2 => try self.airUnOpSimple(inst, .exp2), + .log => try self.airUnOpSimple(inst, .log), + .log2 => try self.airUnOpSimple(inst, .log2), + .log10 => try self.airUnOpSimple(inst, .log10), .abs => try self.airAbs(inst), - .floor => try self.airFloor(inst), + .floor => try self.airUnOpSimple(inst, .floor), + .ceil => try self.airUnOpSimple(inst, .ceil), + .round => try self.airUnOpSimple(inst, .round), + .trunc_float => try self.airUnOpSimple(inst, .trunc), + .neg, .neg_optimized => try self.airUnOpSimple(inst, .f_neg), - .div_floor => try self.airDivFloor(inst), + .div_float, .div_float_optimized => try self.airArithOp(inst, .f_div, .s_div, .u_div), + .div_floor, .div_floor_optimized => try self.airDivFloor(inst), + .div_trunc, .div_trunc_optimized => try self.airDivTrunc(inst), - .div_float, - .div_float_optimized, - .div_trunc, - .div_trunc_optimized => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv), - .rem, .rem_optimized => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem), - .mod, .mod_optimized => try self.airArithOp(inst, .OpFMod, .OpSMod, .OpSMod), + .rem, .rem_optimized => try self.airArithOp(inst, .f_rem, .s_rem, .u_mod), + .mod, .mod_optimized => try self.airArithOp(inst, .f_mod, .s_mod, .u_mod), - - .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan), - .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan), + .add_with_overflow => try self.airAddSubOverflow(inst, .i_add, .u_lt, .s_lt), + .sub_with_overflow => try self.airAddSubOverflow(inst, .i_sub, .u_gt, .s_gt), .mul_with_overflow => try self.airMulOverflow(inst), .shl_with_overflow => try self.airShlOverflow(inst), @@ -2456,6 +3265,8 @@ const DeclGen = struct { .ctz => try self.airClzCtz(inst, .ctz), .clz => try self.airClzCtz(inst, .clz), + .select => try self.airSelect(inst), + .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), .shuffle => try self.airShuffle(inst), @@ -2463,17 +3274,17 @@ const DeclGen = struct { .ptr_add => try self.airPtrAdd(inst), .ptr_sub => try self.airPtrSub(inst), - .bit_and => try self.airBinOpSimple(inst, .OpBitwiseAnd), - .bit_or => try self.airBinOpSimple(inst, 
.OpBitwiseOr), - .xor => try self.airBinOpSimple(inst, .OpBitwiseXor), - .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd), - .bool_or => try self.airBinOpSimple(inst, .OpLogicalOr), + .bit_and => try self.airBinOpSimple(inst, .bit_and), + .bit_or => try self.airBinOpSimple(inst, .bit_or), + .xor => try self.airBinOpSimple(inst, .bit_xor), + .bool_and => try self.airBinOpSimple(inst, .l_and), + .bool_or => try self.airBinOpSimple(inst, .l_or), - .shl, .shl_exact => try self.airShift(inst, .OpShiftLeftLogical, .OpShiftLeftLogical), - .shr, .shr_exact => try self.airShift(inst, .OpShiftRightLogical, .OpShiftRightArithmetic), + .shl, .shl_exact => try self.airShift(inst, .sll, .sll), + .shr, .shr_exact => try self.airShift(inst, .srl, .sra), - .min => try self.airMinMax(inst, .lt), - .max => try self.airMinMax(inst, .gt), + .min => try self.airMinMax(inst, .min), + .max => try self.airMinMax(inst, .max), .bitcast => try self.airBitCast(inst), .intcast, .trunc => try self.airIntCast(inst), @@ -2574,39 +3385,23 @@ const DeclGen = struct { try self.inst_results.putNoClobber(self.gpa, inst, result_id); } - fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef { - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (0..wip.results.len) |i| { - try self.func.body.emit(self.spv.gpa, opcode, .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand_1 = try wip.elementAt(ty, lhs_id, i), - .operand_2 = try wip.elementAt(ty, rhs_id, i), - }); - } - return try wip.finalize(); - } - - fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef { + fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, op: BinaryOp) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOf(bin_op.lhs); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.binOpSimple(ty, lhs_id, rhs_id, opcode); + const result = try self.buildBinary(op, lhs, rhs); + return try result.materialize(self); } - fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime unsigned: Opcode, comptime signed: Opcode) !?IdRef { + fn airShift(self: *DeclGen, inst: Air.Inst.Index, unsigned: BinaryOp, signed: BinaryOp) !?IdRef { const mod = self.module; const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); + + const base = try self.temporary(bin_op.lhs); + const shift = try self.temporary(bin_op.rhs); const result_ty = self.typeOfIndex(inst); - const shift_ty = self.typeOf(bin_op.rhs); - const scalar_result_ty_id = try self.resolveType(result_ty.scalarType(mod), .direct); - const scalar_shift_ty_id = try self.resolveType(shift_ty.scalarType(mod), .direct); const info = self.arithmeticTypeInfo(result_ty); switch (info.class) { @@ -2615,121 +3410,58 @@ const DeclGen = struct { .float, .bool => unreachable, } - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i); + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. 
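+        // (For example, a `u8` shifted by a `u3` first has the shift amount
+        // converted to a `u8`.)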
- // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, - // so just manually upcast it if required. - const shift_id = if (scalar_shift_ty_id != scalar_result_ty_id) blk: { - const shift_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = shift_id, - .unsigned_value = rhs_elem_id, - }); - break :blk shift_id; - } else rhs_elem_id; + // Note: The sign may differ here between the shift and the base type, in case + // of an arithmetic right shift. SPIR-V still expects the same type, + // so in that case we have to cast convert to signed. + const casted_shift = try self.buildIntConvert(base.ty.scalarType(mod), shift); - const value_id = self.spv.allocId(); - const args = .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .base = lhs_elem_id, - .shift = shift_id, - }; + const shifted = switch (info.signedness) { + .unsigned => try self.buildBinary(unsigned, base, casted_shift), + .signed => try self.buildBinary(signed, base, casted_shift), + }; - if (result_ty.isSignedInt(mod)) { - try self.func.body.emit(self.spv.gpa, signed, args); - } else { - try self.func.body.emit(self.spv.gpa, unsigned, args); - } - - result_id.* = try self.normalize(wip.ty, value_id, info); - } - return try wip.finalize(); + const result = try self.normalize(shifted, info); + return try result.materialize(self); } - fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef { + const MinMax = enum { min, max }; + + fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: MinMax) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const result_ty = self.typeOfIndex(inst); - return try self.minMax(result_ty, op, lhs_id, rhs_id); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const result = try self.minMax(lhs, rhs, op); + return try result.materialize(self); } - fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef { - const info = self.arithmeticTypeInfo(result_ty); - const target = self.getTarget(); + fn minMax(self: *DeclGen, lhs: Temporary, rhs: Temporary, op: MinMax) !Temporary { + const info = self.arithmeticTypeInfo(lhs.ty); - const use_backup_codegen = target.os.tag == .opencl and info.class != .float; - var wip = try self.elementWise(result_ty, use_backup_codegen); - defer wip.deinit(); + const binop: BinaryOp = switch (info.class) { + .float => switch (op) { + .min => .f_min, + .max => .f_max, + }, + .integer, .strange_integer => switch (info.signedness) { + .signed => switch (op) { + .min => .s_min, + .max => .s_max, + }, + .unsigned => switch (op) { + .min => .u_min, + .max => .u_max, + }, + }, + .composite_integer => unreachable, // TODO + .bool => unreachable, + }; - for (wip.results, 0..) 
|*result_id, i| {
-            const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
-            const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
-
-            if (use_backup_codegen) {
-                const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id);
-                result_id.* = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                    .id_result_type = wip.ty_id,
-                    .id_result = result_id.*,
-                    .condition = cmp_id,
-                    .object_1 = lhs_elem_id,
-                    .object_2 = rhs_elem_id,
-                });
-            } else {
-                const ext_inst: Word = switch (target.os.tag) {
-                    .opencl => switch (op) {
-                        .lt => 28, // fmin
-                        .gt => 27, // fmax
-                        else => unreachable,
-                    },
-                    .vulkan => switch (info.class) {
-                        .float => switch (op) {
-                            .lt => 37, // FMin
-                            .gt => 40, // FMax
-                            else => unreachable,
-                        },
-                        .integer, .strange_integer => switch (info.signedness) {
-                            .signed => switch (op) {
-                                .lt => 39, // SMin
-                                .gt => 42, // SMax
-                                else => unreachable,
-                            },
-                            .unsigned => switch (op) {
-                                .lt => 38, // UMin
-                                .gt => 41, // UMax
-                                else => unreachable,
-                            },
-                        },
-                        .composite_integer => unreachable, // TODO
-                        .bool => unreachable,
-                    },
-                    else => unreachable,
-                };
-                const set_id = switch (target.os.tag) {
-                    .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
-                    .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
-                    else => unreachable,
-                };
-
-                result_id.* = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
-                    .id_result_type = wip.ty_id,
-                    .id_result = result_id.*,
-                    .set = set_id,
-                    .instruction = .{ .inst = ext_inst },
-                    .id_ref_4 = &.{ lhs_elem_id, rhs_elem_id },
-                });
-            }
-        }
-        return wip.finalize();
+        return try self.buildBinary(binop, lhs, rhs);
     }
 
     /// This function normalizes values to a canonical representation
@@ -2740,41 +3472,24 @@
     /// - Signed integers are also sign extended if they are negative.
     /// All other values are returned unmodified (this makes strange integer
     /// wrapping easier to use in generic operations).
-    fn normalize(self: *DeclGen, ty: Type, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
+    fn normalize(self: *DeclGen, value: Temporary, info: ArithmeticTypeInfo) !Temporary {
+        const mod = self.module;
+        const ty = value.ty;
         switch (info.class) {
-            .integer, .bool, .float => return value_id,
+            .integer, .bool, .float => return value,
             .composite_integer => unreachable, // TODO
             .strange_integer => switch (info.signedness) {
                 .unsigned => {
                     const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
-                    const result_id = self.spv.allocId();
-                    const mask_id = try self.constInt(ty, mask_value, .direct);
-                    try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
-                        .id_result_type = try self.resolveType(ty, .direct),
-                        .id_result = result_id,
-                        .operand_1 = value_id,
-                        .operand_2 = mask_id,
-                    });
-                    return result_id;
+                    const mask_id = try self.constInt(ty.scalarType(mod), mask_value, .direct);
+                    return try self.buildBinary(.bit_and, value, Temporary.init(ty.scalarType(mod), mask_id));
                 },
                 .signed => {
                     // Shift left and right so that we can copy the sign bit that way.
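+                    // (For example, for an i3 with 32 backing bits: shifting left by 29 and
+                    // then arithmetic shifting right by 29 copies bit 2 into bits 3..31.)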
- const shift_amt_id = try self.constInt(ty, info.backing_bits - info.bits, .direct); - const left_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = left_id, - .base = value_id, - .shift = shift_amt_id, - }); - const right_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = right_id, - .base = left_id, - .shift = shift_amt_id, - }); - return right_id; + const shift_amt_id = try self.constInt(ty.scalarType(mod), info.backing_bits - info.bits, .direct); + const shift_amt = Temporary.init(ty.scalarType(mod), shift_amt_id); + const left = try self.buildBinary(.sll, value, shift_amt); + return try self.buildBinary(.sra, left, shift_amt); }, }, } @@ -2782,491 +3497,438 @@ const DeclGen = struct { fn airDivFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOfIndex(inst); - const ty_id = try self.resolveType(ty, .direct); - const info = self.arithmeticTypeInfo(ty); + + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { .composite_integer => unreachable, // TODO .integer, .strange_integer => { - const zero_id = try self.constInt(ty, 0, .direct); - const one_id = try self.constInt(ty, 1, .direct); + switch (info.signedness) { + .unsigned => { + const result = try self.buildBinary(.u_div, lhs, rhs); + return try result.materialize(self); + }, + .signed => {}, + } - // (a ^ b) > 0 - const bin_bitwise_id = try self.binOpSimple(ty, lhs_id, rhs_id, .OpBitwiseXor); - const is_positive_id = try self.cmp(.gt, Type.bool, ty, bin_bitwise_id, zero_id); + // For signed integers: + // (a / b) - (a % b != 0 && a < 0 != b < 0); + // There shouldn't be any overflow issues. 
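+                // (For example, @divFloor(-7, 2): div = -3 and rem = -1; the signs
+                // differ and rem != 0, so the result is -3 - 1 = -4.)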
- // a / b - const positive_div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv); + const div = try self.buildBinary(.s_div, lhs, rhs); + const rem = try self.buildBinary(.s_rem, lhs, rhs); - // - (abs(a) + abs(b) - 1) / abs(b) - const lhs_abs = try self.abs(ty, ty, lhs_id); - const rhs_abs = try self.abs(ty, ty, rhs_id); - const negative_div_lhs = try self.arithOp( - ty, - try self.arithOp(ty, lhs_abs, rhs_abs, .OpFAdd, .OpIAdd, .OpIAdd), - one_id, - .OpFSub, - .OpISub, - .OpISub, + const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct)); + + const rem_is_not_zero = try self.buildCmp(.i_ne, rem, zero); + + const result_negative = try self.buildCmp( + .l_ne, + try self.buildCmp(.s_lt, lhs, zero), + try self.buildCmp(.s_lt, rhs, zero), + ); + const rem_is_not_zero_and_result_is_negative = try self.buildBinary( + .l_and, + rem_is_not_zero, + result_negative, ); - const negative_div_id = try self.arithOp(ty, negative_div_lhs, rhs_abs, .OpFDiv, .OpSDiv, .OpUDiv); - const negated_negative_div_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSNegate, .{ - .id_result_type = ty_id, - .id_result = negated_negative_div_id, - .operand = negative_div_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = ty_id, - .id_result = result_id, - .condition = is_positive_id, - .object_1 = positive_div_id, - .object_2 = negated_negative_div_id, - }); - return result_id; + const result = try self.buildBinary( + .i_sub, + div, + try self.intFromBool2(rem_is_not_zero_and_result_is_negative, div.ty), + ); + + return try result.materialize(self); }, .float => { - const div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv); - return try self.floor(ty, div_id); + const div = try self.buildBinary(.f_div, lhs, rhs); + const result = try self.buildUnary(.floor, div); + return try result.materialize(self); }, .bool => unreachable, } } - fn airFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { - const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; - const operand_id = try self.resolve(un_op); - const result_ty = self.typeOfIndex(inst); - return try self.floor(result_ty, operand_id); + fn airDivTrunc(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; + + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const info = self.arithmeticTypeInfo(lhs.ty); + switch (info.class) { + .composite_integer => unreachable, // TODO + .integer, .strange_integer => switch (info.signedness) { + .unsigned => { + const result = try self.buildBinary(.u_div, lhs, rhs); + return try result.materialize(self); + }, + .signed => { + const result = try self.buildBinary(.s_div, lhs, rhs); + return try result.materialize(self); + }, + }, + .float => { + const div = try self.buildBinary(.f_div, lhs, rhs); + const result = try self.buildUnary(.trunc, div); + return try result.materialize(self); + }, + .bool => unreachable, + } } - fn floor(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { - const target = self.getTarget(); - const ty_id = try self.resolveType(ty, .direct); - const ext_inst: Word = switch (target.os.tag) { - .opencl => 25, - .vulkan => 8, - else => unreachable, - }; - const set_id = switch (target.os.tag) { - .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), - .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), - else => 
unreachable, - }; - - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = ty_id, - .id_result = result_id, - .set = set_id, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{operand_id}, - }); - return result_id; + fn airUnOpSimple(self: *DeclGen, inst: Air.Inst.Index, op: UnaryOp) !?IdRef { + const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; + const operand = try self.temporary(un_op); + const result = try self.buildUnary(op, operand); + return try result.materialize(self); } fn airArithOp( self: *DeclGen, inst: Air.Inst.Index, - comptime fop: Opcode, - comptime sop: Opcode, - comptime uop: Opcode, + comptime fop: BinaryOp, + comptime sop: BinaryOp, + comptime uop: BinaryOp, ) !?IdRef { - // LHS and RHS are guaranteed to have the same type, and AIR guarantees - // the result to be the same as the LHS and RHS, which matches SPIR-V. - const ty = self.typeOfIndex(inst); const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - assert(self.typeOf(bin_op.lhs).eql(ty, self.module)); - assert(self.typeOf(bin_op.rhs).eql(ty, self.module)); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop); - } + const info = self.arithmeticTypeInfo(lhs.ty); - fn arithOp( - self: *DeclGen, - ty: Type, - lhs_id: IdRef, - rhs_id: IdRef, - comptime fop: Opcode, - comptime sop: Opcode, - comptime uop: Opcode, - ) !IdRef { - // Binary operations are generally applicable to both scalar and vector operations - // in SPIR-V, but int and float versions of operations require different opcodes. - const info = self.arithmeticTypeInfo(ty); - - const opcode_index: usize = switch (info.class) { - .composite_integer => { - return self.todo("binary operations for composite integers", .{}); - }, + const result = switch (info.class) { + .composite_integer => unreachable, // TODO .integer, .strange_integer => switch (info.signedness) { - .signed => 1, - .unsigned => 2, + .signed => try self.buildBinary(sop, lhs, rhs), + .unsigned => try self.buildBinary(uop, lhs, rhs), }, - .float => 0, + .float => try self.buildBinary(fop, lhs, rhs), .bool => unreachable, }; - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); - - const value_id = self.spv.allocId(); - const operands = .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .operand_1 = lhs_elem_id, - .operand_2 = rhs_elem_id, - }; - - switch (opcode_index) { - 0 => try self.func.body.emit(self.spv.gpa, fop, operands), - 1 => try self.func.body.emit(self.spv.gpa, sop, operands), - 2 => try self.func.body.emit(self.spv.gpa, uop, operands), - else => unreachable, - } - - // TODO: Trap on overflow? Probably going to be annoying. - // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap. 
-            result_id.* = try self.normalize(wip.ty, value_id, info);
-        }
-
-        return try wip.finalize();
+        return try result.materialize(self);
     }
 
     fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
-        const operand_id = try self.resolve(ty_op.operand);
+        const operand = try self.temporary(ty_op.operand);
         // Note: operand_ty may be signed, while ty is always unsigned!
-        const operand_ty = self.typeOf(ty_op.operand);
         const result_ty = self.typeOfIndex(inst);
-        return try self.abs(result_ty, operand_ty, operand_id);
+        const result = try self.abs(result_ty, operand);
+        return try result.materialize(self);
     }
 
-    fn abs(self: *DeclGen, result_ty: Type, operand_ty: Type, operand_id: IdRef) !IdRef {
+    fn abs(self: *DeclGen, result_ty: Type, value: Temporary) !Temporary {
         const target = self.getTarget();
-        const operand_info = self.arithmeticTypeInfo(operand_ty);
+        const operand_info = self.arithmeticTypeInfo(value.ty);
 
-        var wip = try self.elementWise(result_ty, false);
-        defer wip.deinit();
+        switch (operand_info.class) {
+            .float => return try self.buildUnary(.f_abs, value),
+            .integer, .strange_integer => {
+                const abs_value = try self.buildUnary(.i_abs, value);
 
-        for (wip.results, 0..) |*result_id, i| {
-            const elem_id = try wip.elementAt(operand_ty, operand_id, i);
+                // TODO: We may need to bitcast the result to a uint
+                // depending on the result type. Do that when
+                // bitCast is implemented for vectors.
+                // This is only relevant for Vulkan
+                assert(target.os.tag != .vulkan); // TODO
 
-            const ext_inst: Word = switch (target.os.tag) {
-                .opencl => switch (operand_info.class) {
-                    .float => 23, // fabs
-                    .integer, .strange_integer => switch (operand_info.signedness) {
-                        .signed => 141, // s_abs
-                        .unsigned => 201, // u_abs
-                    },
-                    .composite_integer => unreachable, // TODO
-                    .bool => unreachable,
-                },
-                .vulkan => switch (operand_info.class) {
-                    .float => 4, // FAbs
-                    .integer, .strange_integer => 5, // SAbs
-                    .composite_integer => unreachable, // TODO
-                    .bool => unreachable,
-                },
-                else => unreachable,
-            };
-            const set_id = switch (target.os.tag) {
-                .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
-                .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
-                else => unreachable,
-            };
-
-            result_id.* = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
-                .id_result_type = wip.ty_id,
-                .id_result = result_id.*,
-                .set = set_id,
-                .instruction = .{ .inst = ext_inst },
-                .id_ref_4 = &.{elem_id},
-            });
+                return try self.normalize(abs_value, self.arithmeticTypeInfo(result_ty));
+            },
+            .composite_integer => unreachable, // TODO
+            .bool => unreachable,
         }
-        return try wip.finalize();
     }
 
     fn airAddSubOverflow(
         self: *DeclGen,
         inst: Air.Inst.Index,
-        comptime add: Opcode,
-        comptime ucmp: Opcode,
-        comptime scmp: Opcode,
+        comptime add: BinaryOp,
+        comptime ucmp: CmpPredicate,
+        comptime scmp: CmpPredicate,
     ) !?IdRef {
-        const mod = self.module;
+        // Note: OpIAddCarry and OpISubBorrow are not really useful here: For unsigned
+        // numbers, only one extra operation is required in both cases. For signed
+        // operations, the overflow bit is set when going from 0x80.. to 0x00.., but this
+        // doesn't actually normally set a carry bit. So the SPIR-V overflow operations
+        // are not particularly useful here.
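+        //
+        // The strategy below: perform the wrapping operation, normalize the result,
+        // and then compare it with the operands to detect whether overflow occurred.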
+ const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const lhs = try self.temporary(extra.lhs); + const rhs = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const ov_ty = result_ty.structFieldType(1, self.module); - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isSpvVector(operand_ty)) - // TODO: Resolving a vector type with .direct should return a SPIR-V vector - try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) - else - bool_ty_id; - - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { - .composite_integer => return self.todo("overflow ops for composite integers", .{}), + .composite_integer => unreachable, // TODO .strange_integer, .integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, false); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, false); - defer wip_ov.deinit(); - for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| { - const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); - const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i); + const sum = try self.buildBinary(add, lhs, rhs); + const result = try self.normalize(sum, info); - // Normalize both so that we can properly check for overflow - const value_id = self.spv.allocId(); + const overflowed = switch (info.signedness) { + // Overflow happened if the result is smaller than either of the operands. It doesn't matter which. + // For subtraction the conditions need to be swapped. + .unsigned => try self.buildCmp(ucmp, result, lhs), + // For addition, overflow happened if: + // - rhs is negative and value > lhs + // - rhs is positive and value < lhs + // This can be shortened to: + // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs) + // = (rhs < 0) == (value > lhs) + // = (rhs < 0) == (lhs < value) + // Note that signed overflow is also wrapping in spir-v. + // For subtraction, overflow happened if: + // - rhs is negative and value < lhs + // - rhs is positive and value > lhs + // This can be shortened to: + // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs) + // = (rhs < 0) == (value < lhs) + // = (rhs < 0) == (lhs > value) + .signed => blk: { + const zero = Temporary.init(rhs.ty, try self.constInt(rhs.ty, 0, .direct)); + const rhs_lt_zero = try self.buildCmp(.s_lt, rhs, zero); + const result_gt_lhs = try self.buildCmp(scmp, lhs, result); + break :blk try self.buildCmp(.l_eq, rhs_lt_zero, result_gt_lhs); + }, + }; - try self.func.body.emit(self.spv.gpa, add, .{ - .id_result_type = wip_result.ty_id, - .id_result = value_id, - .operand_1 = lhs_elem_id, - .operand_2 = rhs_elem_id, - }); - - // Normalize the result so that the comparisons go well - result_id.* = try self.normalize(wip_result.ty, value_id, info); - - const overflowed_id = switch (info.signedness) { - .unsigned => blk: { - // Overflow happened if the result is smaller than either of the operands. It doesn't matter which. - // For subtraction the conditions need to be swapped. 
- const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, ucmp, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = result_id.*, - .operand_2 = lhs_elem_id, - }); - break :blk overflowed_id; - }, - .signed => blk: { - // lhs - rhs - // For addition, overflow happened if: - // - rhs is negative and value > lhs - // - rhs is positive and value < lhs - // This can be shortened to: - // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs) - // = (rhs < 0) == (value > lhs) - // = (rhs < 0) == (lhs < value) - // Note that signed overflow is also wrapping in spir-v. - // For subtraction, overflow happened if: - // - rhs is negative and value < lhs - // - rhs is positive and value > lhs - // This can be shortened to: - // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs) - // = (rhs < 0) == (value < lhs) - // = (rhs < 0) == (lhs > value) - - const rhs_lt_zero_id = self.spv.allocId(); - const zero_id = try self.constInt(wip_result.ty, 0, .direct); - try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{ - .id_result_type = cmp_ty_id, - .id_result = rhs_lt_zero_id, - .operand_1 = rhs_elem_id, - .operand_2 = zero_id, - }); - - const value_gt_lhs_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, scmp, .{ - .id_result_type = cmp_ty_id, - .id_result = value_gt_lhs_id, - .operand_1 = lhs_elem_id, - .operand_2 = result_id.*, - }); - - const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = rhs_lt_zero_id, - .operand_2 = value_gt_lhs_id, - }); - break :blk overflowed_id; - }, - }; - - ov_id.* = try self.intFromBool(wip_ov.ty, overflowed_id); - } + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } fn airMulOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const target = self.getTarget(); + const mod = self.module; + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const lhs = try self.temporary(extra.lhs); + const rhs = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const ov_ty = result_ty.structFieldType(1, self.module); - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { - .composite_integer => return self.todo("overflow ops for composite integers", .{}), + .composite_integer => unreachable, // TODO .strange_integer, .integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, true); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, true); - defer wip_ov.deinit(); + // There are 3 cases which we have to deal with: + // - If info.bits < 32 / 2, we will upcast to 32 and check the higher bits + // - If info.bits > 32 / 2, we have to use extended multiplication + // - Additionally, if info.bits != 32, we'll have to check the high bits + // of the result too. 
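+        // For example: a u16 multiply is performed in u32, and overflow occurred if any
+        // of the upper 16 bits of the product are set. A u48 multiply has no wider type
+        // to upcast to, so it uses wide multiplication and checks the high word as well
+        // as bits 48..63 of the low word.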
-        const zero_id = try self.constInt(wip_result.ty, 0, .direct);
-        const zero_ov_id = try self.constInt(wip_ov.ty, 0, .direct);
-        const one_ov_id = try self.constInt(wip_ov.ty, 1, .direct);
+        const largest_int_bits: u16 = if (Target.spirv.featureSetHas(target.cpu.features, .Int64)) 64 else 32;
+        // If non-null, the number of bits that the multiplication should be performed in. If
+        // null, we have to use wide multiplication.
+        const maybe_op_ty_bits: ?u16 = switch (info.bits) {
+            0 => unreachable,
+            1...16 => 32,
+            17...32 => if (largest_int_bits > 32) 64 else null, // Upcast if we can.
+            33...64 => null, // Always use wide multiplication.
+            else => unreachable, // TODO: Composite integers
+        };
 
-        for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
-            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
-            const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
+        const result, const overflowed = switch (info.signedness) {
+            .unsigned => blk: {
+                if (maybe_op_ty_bits) |op_ty_bits| {
+                    const op_ty = try mod.intType(.unsigned, op_ty_bits);
+                    const casted_lhs = try self.buildIntConvert(op_ty, lhs);
+                    const casted_rhs = try self.buildIntConvert(op_ty, rhs);
 
-            result_id.* = try self.arithOp(wip_result.ty, lhs_elem_id, rhs_elem_id, .OpFMul, .OpIMul, .OpIMul);
+                    const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs);
 
-            // (a != 0) and (x / a != b)
-            const not_zero_id = try self.cmp(.neq, Type.bool, wip_result.ty, lhs_elem_id, zero_id);
-            const res_rhs_id = try self.arithOp(wip_result.ty, result_id.*, lhs_elem_id, .OpFDiv, .OpSDiv, .OpUDiv);
-            const res_rhs_not_rhs_id = try self.cmp(.neq, Type.bool, wip_result.ty, res_rhs_id, rhs_elem_id);
-            const cond_id = try self.binOpSimple(Type.bool, not_zero_id, res_rhs_not_rhs_id, .OpLogicalAnd);
+                    const low_bits = try self.buildIntConvert(lhs.ty, full_result);
+                    const result = try self.normalize(low_bits, info);
 
-            ov_id.* = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                .id_result_type = wip_ov.ty_id,
-                .id_result = ov_id.*,
-                .condition = cond_id,
-                .object_1 = one_ov_id,
-                .object_2 = zero_ov_id,
-            });
-        }
+                    // Shift the result bits away to get the overflow bits.
+                    const shift = Temporary.init(full_result.ty, try self.constInt(full_result.ty, info.bits, .direct));
+                    const overflow = try self.buildBinary(.srl, full_result, shift);
+
+                    // Directly check if it's zero in the op_ty without converting first.
+                    const zero = Temporary.init(full_result.ty, try self.constInt(full_result.ty, 0, .direct));
+                    const overflowed = try self.buildCmp(.i_ne, zero, overflow);
+
+                    break :blk .{ result, overflowed };
+                }
+
+                const low_bits, const high_bits = try self.buildWideMul(.u_mul_extended, lhs, rhs);
+
+                // Truncate the result, if required.
+                const result = try self.normalize(low_bits, info);
+
+                // Overflow happened if the high-bits of the result are non-zero OR if the
+                // high bits of the low word of the result (those outside the range of the
+                // int) are nonzero.
+                const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct));
+                const high_overflowed = try self.buildCmp(.i_ne, zero, high_bits);
+
+                // If no overflow bits in low_bits, no extra work needs to be done.
+                if (info.backing_bits == info.bits) {
+                    break :blk .{ result, high_overflowed };
+                }
+
+                // Shift the result bits away to get the overflow bits.
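+                // (For a strange integer such as u48 stored in a 64-bit word, shifting
+                // right by info.bits leaves exactly the spare backing bits.)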
+                const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits, .direct));
+                const low_overflow = try self.buildBinary(.srl, low_bits, shift);
+                const low_overflowed = try self.buildCmp(.i_ne, zero, low_overflow);
+
+                const overflowed = try self.buildBinary(.l_or, low_overflowed, high_overflowed);
+
+                break :blk .{ result, overflowed };
+            },
+            .signed => blk: {
+                // - lhs >= 0, rhs >= 0: expect positive; overflow should be 0
+                // - lhs == 0: expect positive; overflow should be 0
+                // - rhs == 0: expect positive; overflow should be 0
+                // - lhs > 0, rhs < 0: expect negative; overflow should be -1
+                // - lhs < 0, rhs > 0: expect negative; overflow should be -1
+                // - lhs <= 0, rhs <= 0: expect positive; overflow should be 0
+                // ------
+                // overflow should be -1 when
+                // (lhs > 0 && rhs < 0) || (lhs < 0 && rhs > 0)
+
+                const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct));
+                const lhs_negative = try self.buildCmp(.s_lt, lhs, zero);
+                const rhs_negative = try self.buildCmp(.s_lt, rhs, zero);
+                const lhs_positive = try self.buildCmp(.s_gt, lhs, zero);
+                const rhs_positive = try self.buildCmp(.s_gt, rhs, zero);
+
+                // Set to `true` if we expect -1.
+                const expected_overflow_bit = try self.buildBinary(
+                    .l_or,
+                    try self.buildBinary(.l_and, lhs_positive, rhs_negative),
+                    try self.buildBinary(.l_and, lhs_negative, rhs_positive),
+                );
+
+                if (maybe_op_ty_bits) |op_ty_bits| {
+                    const op_ty = try mod.intType(.signed, op_ty_bits);
+                    // Assume normalized; sign bit is set. We want a sign extend.
+                    const casted_lhs = try self.buildIntConvert(op_ty, lhs);
+                    const casted_rhs = try self.buildIntConvert(op_ty, rhs);
+
+                    const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs);
+
+                    // Truncate to the result type.
+                    const low_bits = try self.buildIntConvert(lhs.ty, full_result);
+                    const result = try self.normalize(low_bits, info);
+
+                    // Now, we need to check the overflow bits AND the sign
+                    // bit for the expected overflow bits.
+                    // To do that, shift out everything but the sign bit and
+                    // then check what remains.
+                    const shift = Temporary.init(full_result.ty, try self.constInt(full_result.ty, info.bits - 1, .direct));
+                    // Use SRA so that any sign bits are duplicated. Now we can just check if ALL bits are set
+                    // for negative cases.
+                    const overflow = try self.buildBinary(.sra, full_result, shift);
+
+                    const long_all_set = Temporary.init(full_result.ty, try self.constInt(full_result.ty, -1, .direct));
+                    const long_zero = Temporary.init(full_result.ty, try self.constInt(full_result.ty, 0, .direct));
+                    const mask = try self.buildSelect(expected_overflow_bit, long_all_set, long_zero);
+
+                    const overflowed = try self.buildCmp(.i_ne, mask, overflow);
+
+                    break :blk .{ result, overflowed };
+                }
+
+                const low_bits, const high_bits = try self.buildWideMul(.s_mul_extended, lhs, rhs);
+
+                // Truncate result if required.
+                const result = try self.normalize(low_bits, info);
+
+                const all_set = Temporary.init(lhs.ty, try self.constInt(lhs.ty, -1, .direct));
+                const mask = try self.buildSelect(expected_overflow_bit, all_set, zero);
+
+                // Like with unsigned, overflow happened if high_bits are not the ones we expect,
+                // and we also need to check some of the low bits.
+
+                const high_overflowed = try self.buildCmp(.i_ne, mask, high_bits);
+
+                // If no overflow bits in low_bits, no extra work needs to be done.
+                // Careful, we still have to check the sign bit, so this branch
+                // only goes for i33 and such.
+ if (info.backing_bits == info.bits + 1) { + break :blk .{ result, high_overflowed }; + } + + // Shift the result bits away to get the overflow bits. + const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits - 1, .direct)); + // Use SRA so that any sign bits are duplicated. Now we can just check if ALL bits are set + // for negative cases. + const low_overflow = try self.buildBinary(.sra, low_bits, shift); + const low_overflowed = try self.buildCmp(.i_ne, mask, low_overflow); + + const overflowed = try self.buildBinary(.l_or, low_overflowed, high_overflowed); + + break :blk .{ result, overflowed }; + }, + }; + + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const mod = self.module; + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const base = try self.temporary(extra.lhs); + const shift = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const shift_ty = self.typeOf(extra.rhs); - const scalar_shift_ty_id = try self.resolveType(shift_ty.scalarType(mod), .direct); - const scalar_operand_ty_id = try self.resolveType(operand_ty.scalarType(mod), .direct); - const ov_ty = result_ty.structFieldType(1, self.module); - - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isSpvVector(operand_ty)) - // TODO: Resolving a vector type with .direct should return a SPIR-V vector - try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) - else - bool_ty_id; - - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(base.ty); switch (info.class) { - .composite_integer => return self.todo("overflow shift for composite integers", .{}), + .composite_integer => unreachable, // TODO .integer, .strange_integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, false); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, false); - defer wip_ov.deinit(); - for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| { - const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); - const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i); + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. + const casted_shift = try self.buildIntConvert(base.ty.scalarType(mod), shift); - // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, - // so just manually upcast it if required. 
- const shift_id = if (scalar_shift_ty_id != scalar_operand_ty_id) blk: { - const shift_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip_result.ty_id, - .id_result = shift_id, - .unsigned_value = rhs_elem_id, - }); - break :blk shift_id; - } else rhs_elem_id; + const left = try self.buildBinary(.sll, base, casted_shift); + const result = try self.normalize(left, info); - const value_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{ - .id_result_type = wip_result.ty_id, - .id_result = value_id, - .base = lhs_elem_id, - .shift = shift_id, - }); - result_id.* = try self.normalize(wip_result.ty, value_id, info); + const right = switch (info.signedness) { + .unsigned => try self.buildBinary(.srl, result, casted_shift), + .signed => try self.buildBinary(.sra, result, casted_shift), + }; - const right_shift_id = self.spv.allocId(); - switch (info.signedness) { - .signed => { - try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{ - .id_result_type = wip_result.ty_id, - .id_result = right_shift_id, - .base = result_id.*, - .shift = shift_id, - }); - }, - .unsigned => { - try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{ - .id_result_type = wip_result.ty_id, - .id_result = right_shift_id, - .base = result_id.*, - .shift = shift_id, - }); - }, - } - - const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = lhs_elem_id, - .operand_2 = right_shift_id, - }); - - ov_id.* = try self.intFromBool(wip_ov.ty, overflowed_id); - } + const overflowed = try self.buildCmp(.i_ne, base, right); + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } @@ -3274,122 +3936,67 @@ const DeclGen = struct { const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const extra = self.air.extraData(Air.Bin, pl_op.payload).data; - const mulend1 = try self.resolve(extra.lhs); - const mulend2 = try self.resolve(extra.rhs); - const addend = try self.resolve(pl_op.operand); + const a = try self.temporary(extra.lhs); + const b = try self.temporary(extra.rhs); + const c = try self.temporary(pl_op.operand); - const ty = self.typeOfIndex(inst); - - const info = self.arithmeticTypeInfo(ty); + const result_ty = self.typeOfIndex(inst); + const info = self.arithmeticTypeInfo(result_ty); assert(info.class == .float); // .mul_add is only emitted for floats - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (0..wip.results.len) |i| { - const mul_result = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpFMul, .{ - .id_result_type = wip.ty_id, - .id_result = mul_result, - .operand_1 = try wip.elementAt(ty, mulend1, i), - .operand_2 = try wip.elementAt(ty, mulend2, i), - }); - - try self.func.body.emit(self.spv.gpa, .OpFAdd, .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand_1 = mul_result, - .operand_2 = try wip.elementAt(ty, addend, i), - }); - } - return try wip.finalize(); + const result = try self.buildFma(a, b, c); + return try result.materialize(self); } - fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef { + fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: UnaryOp) !?IdRef { 
if (self.liveness.isUnused(inst)) return null; const mod = self.module; const target = self.getTarget(); const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(ty_op.operand); - const operand = try self.resolve(ty_op.operand); + const operand = try self.temporary(ty_op.operand); - const info = self.arithmeticTypeInfo(operand_ty); + const scalar_result_ty = self.typeOfIndex(inst).scalarType(mod); + + const info = self.arithmeticTypeInfo(operand.ty); switch (info.class) { .composite_integer => unreachable, // TODO .integer, .strange_integer => {}, .float, .bool => unreachable, } - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - - const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty; - const elem_ty_id = try self.resolveType(elem_ty, .direct); - - for (wip.results, 0..) |*result_id, i| { - const elem = try wip.elementAt(operand_ty, operand, i); - - switch (target.os.tag) { - .opencl => { - const set = try self.spv.importInstructionSet(.@"OpenCL.std"); - const ext_inst: u32 = switch (op) { - .clz => 151, // clz - .ctz => 152, // ctz - }; - - // Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty. - // result_ty is always large enough to hold the result, so we might have to down - // cast it. - const tmp = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = elem_ty_id, - .id_result = tmp, - .set = set, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{elem}, - }); - - // TODO: Comparison should be removed.. - // Its valid because SpvModule caches numeric types - if (wip.ty_id == elem_ty_id) { - result_id.* = tmp; - continue; - } - - result_id.* = self.spv.allocId(); - if (result_ty.scalarType(mod).isSignedInt(mod)) { - assert(elem_ty.scalarType(mod).isSignedInt(mod)); - try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .signed_value = tmp, - }); - } else { - assert(elem_ty.scalarType(mod).isUnsignedInt(mod)); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .unsigned_value = tmp, - }); - } - }, - .vulkan => unreachable, // TODO - else => unreachable, - } + switch (target.os.tag) { + .vulkan => unreachable, // TODO + else => {}, } - return try wip.finalize(); + const count = try self.buildUnary(op, operand); + + // Result of OpenCL ctz/clz returns operand.ty, and we want result_ty. + // result_ty is always large enough to hold the result, so we might have to down + // cast it. 
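+        // (For example, @clz on a u64 yields a u7 in Zig, so the 64-bit count is
+        // narrowed here.)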
+ const result = try self.buildIntConvert(scalar_result_ty, count); + return try result.materialize(self); + } + + fn airSelect(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; + const pred = try self.temporary(pl_op.operand); + const a = try self.temporary(extra.lhs); + const b = try self.temporary(extra.rhs); + + const result = try self.buildSelect(pred, a, b); + return try result.materialize(self); } fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; + const operand_id = try self.resolve(ty_op.operand); const result_ty = self.typeOfIndex(inst); - var wip = try self.elementWise(result_ty, true); - defer wip.deinit(); - @memset(wip.results, operand_id); - return try wip.finalize(); + + return try self.constructVectorSplat(result_ty, operand_id); } fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -3402,23 +4009,33 @@ const DeclGen = struct { const info = self.arithmeticTypeInfo(operand_ty); - var result_id = try self.extractVectorComponent(scalar_ty, operand, 0); const len = operand_ty.vectorLen(mod); + const first = try self.extractVectorComponent(scalar_ty, operand, 0); + switch (reduce.operation) { .Min, .Max => |op| { - const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt; + var result = Temporary.init(scalar_ty, first); + const cmp_op: MinMax = switch (op) { + .Max => .max, + .Min => .min, + else => unreachable, + }; for (1..len) |i| { - const lhs = result_id; - const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); - result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs); + const lhs = result; + const rhs_id = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); + const rhs = Temporary.init(scalar_ty, rhs_id); + + result = try self.minMax(lhs, rhs, cmp_op); } - return result_id; + return try result.materialize(self); }, else => {}, } + var result_id = first; + const opcode: Opcode = switch (info.class) { .bool => switch (reduce.operation) { .And => .OpLogicalAnd, @@ -3602,50 +4219,66 @@ const DeclGen = struct { fn cmp( self: *DeclGen, op: std.math.CompareOperator, - result_ty: Type, - ty: Type, - lhs_id: IdRef, - rhs_id: IdRef, - ) !IdRef { + lhs: Temporary, + rhs: Temporary, + ) !Temporary { const mod = self.module; - var cmp_lhs_id = lhs_id; - var cmp_rhs_id = rhs_id; - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const op_ty = switch (ty.zigTypeTag(mod)) { - .Int, .Bool, .Float => ty, - .Enum => ty.intTagType(mod), - .ErrorSet => Type.u16, - .Pointer => blk: { + const scalar_ty = lhs.ty.scalarType(mod); + const is_vector = lhs.ty.isVector(mod); + + switch (scalar_ty.zigTypeTag(mod)) { + .Int, .Bool, .Float => {}, + .Enum => { + assert(!is_vector); + const ty = lhs.ty.intTagType(mod); + return try self.cmp(op, lhs.pun(ty), rhs.pun(ty)); + }, + .ErrorSet => { + assert(!is_vector); + return try self.cmp(op, lhs.pun(Type.u16), rhs.pun(Type.u16)); + }, + .Pointer => { + assert(!is_vector); // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using // OpConvertPtrToU... 
- cmp_lhs_id = self.spv.allocId(); - cmp_rhs_id = self.spv.allocId(); const usize_ty_id = try self.resolveType(Type.usize, .direct); + const lhs_int_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ .id_result_type = usize_ty_id, - .id_result = cmp_lhs_id, - .pointer = lhs_id, + .id_result = lhs_int_id, + .pointer = try lhs.materialize(self), }); + const rhs_int_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ .id_result_type = usize_ty_id, - .id_result = cmp_rhs_id, - .pointer = rhs_id, + .id_result = rhs_int_id, + .pointer = try rhs.materialize(self), }); - break :blk Type.usize; + const lhs_int = Temporary.init(Type.usize, lhs_int_id); + const rhs_int = Temporary.init(Type.usize, rhs_int_id); + return try self.cmp(op, lhs_int, rhs_int); }, .Optional => { + assert(!is_vector); + + const ty = lhs.ty; + const payload_ty = ty.optionalChild(mod); if (ty.optionalReprIsPayload(mod)) { assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod)); assert(!payload_ty.isSlice(mod)); - return self.cmp(op, Type.bool, payload_ty, lhs_id, rhs_id); + + return try self.cmp(op, lhs.pun(payload_ty), rhs.pun(payload_ty)); } + const lhs_id = try lhs.materialize(self); + const rhs_id = try rhs.materialize(self); + const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) try self.extractField(Type.bool, lhs_id, 1) else @@ -3656,8 +4289,11 @@ const DeclGen = struct { else try self.convertToDirect(Type.bool, rhs_id); + const lhs_valid = Temporary.init(Type.bool, lhs_valid_id); + const rhs_valid = Temporary.init(Type.bool, rhs_valid_id); + if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) { - return try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); + return try self.cmp(op, lhs_valid, rhs_valid); } // a = lhs_valid @@ -3678,118 +4314,71 @@ const DeclGen = struct { const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0); const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0); - switch (op) { - .eq => { - const valid_eq_id = try self.cmp(.eq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); - const pl_eq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id); - const lhs_not_valid_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{ - .id_result_type = bool_ty_id, - .id_result = lhs_not_valid_id, - .operand = lhs_valid_id, - }); - const impl_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ - .id_result_type = bool_ty_id, - .id_result = impl_id, - .operand_1 = lhs_not_valid_id, - .operand_2 = pl_eq_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{ - .id_result_type = bool_ty_id, - .id_result = result_id, - .operand_1 = valid_eq_id, - .operand_2 = impl_id, - }); - return result_id; - }, - .neq => { - const valid_neq_id = try self.cmp(.neq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); - const pl_neq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id); + const lhs_pl = Temporary.init(payload_ty, lhs_pl_id); + const rhs_pl = Temporary.init(payload_ty, rhs_pl_id); - const impl_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{ - .id_result_type = bool_ty_id, - .id_result = impl_id, - .operand_1 = lhs_valid_id, - .operand_2 = pl_neq_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ - .id_result_type = bool_ty_id, - .id_result = result_id, - .operand_1 = valid_neq_id, - 
.operand_2 = impl_id, - }); - return result_id; - }, + return switch (op) { + .eq => try self.buildBinary( + .l_and, + try self.cmp(.eq, lhs_valid, rhs_valid), + try self.buildBinary( + .l_or, + try self.buildUnary(.l_not, lhs_valid), + try self.cmp(.eq, lhs_pl, rhs_pl), + ), + ), + .neq => try self.buildBinary( + .l_or, + try self.cmp(.neq, lhs_valid, rhs_valid), + try self.buildBinary( + .l_and, + lhs_valid, + try self.cmp(.neq, lhs_pl, rhs_pl), + ), + ), else => unreachable, - } - }, - .Vector => { - var wip = try self.elementWise(result_ty, true); - defer wip.deinit(); - const scalar_ty = ty.scalarType(mod); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); - result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id); - } - return wip.finalize(); + }; }, else => unreachable, - }; + } - const opcode: Opcode = opcode: { - const info = self.arithmeticTypeInfo(op_ty); - const signedness = switch (info.class) { - .composite_integer => { - return self.todo("binary operations for composite integers", .{}); - }, - .float => break :opcode switch (op) { - .eq => .OpFOrdEqual, - .neq => .OpFUnordNotEqual, - .lt => .OpFOrdLessThan, - .lte => .OpFOrdLessThanEqual, - .gt => .OpFOrdGreaterThan, - .gte => .OpFOrdGreaterThanEqual, - }, - .bool => break :opcode switch (op) { - .eq => .OpLogicalEqual, - .neq => .OpLogicalNotEqual, - else => unreachable, - }, - .integer, .strange_integer => info.signedness, - }; - - break :opcode switch (signedness) { - .unsigned => switch (op) { - .eq => .OpIEqual, - .neq => .OpINotEqual, - .lt => .OpULessThan, - .lte => .OpULessThanEqual, - .gt => .OpUGreaterThan, - .gte => .OpUGreaterThanEqual, - }, + const info = self.arithmeticTypeInfo(scalar_ty); + const pred: CmpPredicate = switch (info.class) { + .composite_integer => unreachable, // TODO + .float => switch (op) { + .eq => .f_oeq, + .neq => .f_une, + .lt => .f_olt, + .lte => .f_ole, + .gt => .f_ogt, + .gte => .f_oge, + }, + .bool => switch (op) { + .eq => .l_eq, + .neq => .l_ne, + else => unreachable, + }, + .integer, .strange_integer => switch (info.signedness) { .signed => switch (op) { - .eq => .OpIEqual, - .neq => .OpINotEqual, - .lt => .OpSLessThan, - .lte => .OpSLessThanEqual, - .gt => .OpSGreaterThan, - .gte => .OpSGreaterThanEqual, + .eq => .i_eq, + .neq => .i_ne, + .lt => .s_lt, + .lte => .s_le, + .gt => .s_gt, + .gte => .s_ge, }, - }; + .unsigned => switch (op) { + .eq => .i_eq, + .neq => .i_ne, + .lt => .u_lt, + .lte => .u_le, + .gt => .u_gt, + .gte => .u_ge, + }, + }, }; - const result_id = self.spv.allocId(); - try self.func.body.emitRaw(self.spv.gpa, opcode, 4); - self.func.body.writeOperand(spec.IdResultType, bool_ty_id); - self.func.body.writeOperand(spec.IdResult, result_id); - self.func.body.writeOperand(spec.IdResultType, cmp_lhs_id); - self.func.body.writeOperand(spec.IdResultType, cmp_rhs_id); - return result_id; + return try self.buildCmp(pred, lhs, rhs); } fn airCmp( @@ -3798,24 +4387,22 @@ const DeclGen = struct { comptime op: std.math.CompareOperator, ) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOf(bin_op.lhs); - const result_ty = self.typeOfIndex(inst); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); + const 
result = try self.cmp(op, lhs, rhs); + return try result.materialize(self); } fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const vec_cmp = self.air.extraData(Air.VectorCmp, ty_pl.payload).data; - const lhs_id = try self.resolve(vec_cmp.lhs); - const rhs_id = try self.resolve(vec_cmp.rhs); + const lhs = try self.temporary(vec_cmp.lhs); + const rhs = try self.temporary(vec_cmp.rhs); const op = vec_cmp.compareOperator(); - const ty = self.typeOf(vec_cmp.lhs); - const result_ty = self.typeOfIndex(inst); - return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); + const result = try self.cmp(op, lhs, rhs); + return try result.materialize(self); } /// Bitcast one type to another. Note: both types, input, output are expected in **direct** representation. @@ -3881,7 +4468,8 @@ const DeclGen = struct { // should we change the representation of strange integers? if (dst_ty.zigTypeTag(mod) == .Int) { const info = self.arithmeticTypeInfo(dst_ty); - return try self.normalize(dst_ty, result_id, info); + const result = try self.normalize(Temporary.init(dst_ty, result_id), info); + return try result.materialize(self); } return result_id; @@ -3897,46 +4485,28 @@ const DeclGen = struct { fn airIntCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const operand_id = try self.resolve(ty_op.operand); - const src_ty = self.typeOf(ty_op.operand); + const src = try self.temporary(ty_op.operand); const dst_ty = self.typeOfIndex(inst); - const src_info = self.arithmeticTypeInfo(src_ty); + const src_info = self.arithmeticTypeInfo(src.ty); const dst_info = self.arithmeticTypeInfo(dst_ty); if (src_info.backing_bits == dst_info.backing_bits) { - return operand_id; + return try src.materialize(self); } - var wip = try self.elementWise(dst_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const elem_id = try wip.elementAt(src_ty, operand_id, i); - const value_id = self.spv.allocId(); - switch (dst_info.signedness) { - .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .signed_value = elem_id, - }), - .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .unsigned_value = elem_id, - }), - } + const converted = try self.buildIntConvert(dst_ty, src); - // Make sure to normalize the result if shrinking. - // Because strange ints are sign extended in their backing - // type, we don't need to normalize when growing the type. The - // representation is already the same. - if (dst_info.bits < src_info.bits) { - result_id.* = try self.normalize(wip.ty, value_id, dst_info); - } else { - result_id.* = value_id; - } - } - return try wip.finalize(); + // Make sure to normalize the result if shrinking. + // Because strange ints are sign extended in their backing + // type, we don't need to normalize when growing the type. The + // representation is already the same. 
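+        // (For example, widening an i12 to an i24 just sign extends the backing word,
+        // while shrinking an i24 to an i12 can leave out-of-range bits behind, hence
+        // the normalize below.)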
+ const result = if (dst_info.bits < src_info.bits) + try self.normalize(converted, dst_info) + else + converted; + + return try result.materialize(self); } fn intFromPtr(self: *DeclGen, operand_id: IdRef) !IdRef { @@ -4011,16 +4581,9 @@ const DeclGen = struct { fn airIntFromBool(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; - const operand_id = try self.resolve(un_op); - const result_ty = self.typeOfIndex(inst); - - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const elem_id = try wip.elementAt(Type.bool, operand_id, i); - result_id.* = try self.intFromBool(wip.ty, elem_id); - } - return try wip.finalize(); + const operand = try self.temporary(un_op); + const result = try self.intFromBool(operand); + return try result.materialize(self); } fn airFloatCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -4040,33 +4603,21 @@ const DeclGen = struct { fn airNot(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const operand_id = try self.resolve(ty_op.operand); + const operand = try self.temporary(ty_op.operand); const result_ty = self.typeOfIndex(inst); const info = self.arithmeticTypeInfo(result_ty); - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); + const result = switch (info.class) { + .bool => try self.buildUnary(.l_not, operand), + .float => unreachable, + .composite_integer => unreachable, // TODO + .strange_integer, .integer => blk: { + const complement = try self.buildUnary(.bit_not, operand); + break :blk try self.normalize(complement, info); + }, + }; - for (0..wip.results.len) |i| { - const args = .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand = try wip.elementAt(result_ty, operand_id, i), - }; - switch (info.class) { - .bool => { - try self.func.body.emit(self.spv.gpa, .OpLogicalNot, args); - }, - .float => unreachable, - .composite_integer => unreachable, // TODO - .strange_integer, .integer => { - // Note: strange integer bits will be masked before operations that do not hold under modulo. - try self.func.body.emit(self.spv.gpa, .OpNot, args); - }, - } - } - - return try wip.finalize(); + return try result.materialize(self); } fn airArrayToSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -4338,8 +4889,11 @@ const DeclGen = struct { // For now, just generate a temporary and use that. // TODO: This backend probably also should use isByRef from llvm... 
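+            // Vector elements are stored in direct representation, so they can be
+            // loaded as-is; array elements are stored in indirect representation and
+            // may need a conversion after the load.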
+ const is_vector = array_ty.isVector(mod); + + const elem_repr: Repr = if (is_vector) .direct else .indirect; const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct); - const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct); + const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, elem_repr); const tmp_id = self.spv.allocId(); try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ @@ -4357,12 +4911,12 @@ const DeclGen = struct { const result_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpLoad, .{ - .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result_type = try self.resolveType(elem_ty, elem_repr), .id_result = result_id, .pointer = elem_ptr_id, }); - if (array_ty.isVector(mod)) { + if (is_vector) { // Result is already in direct representation return result_id; } @@ -4585,7 +5139,10 @@ const DeclGen = struct { if (field_offset == 0) break :base_ptr_int field_ptr_int; const field_offset_id = try self.constInt(Type.usize, field_offset, .direct); - break :base_ptr_int try self.binOpSimple(Type.usize, field_ptr_int, field_offset_id, .OpISub); + const field_ptr_tmp = Temporary.init(Type.usize, field_ptr_int); + const field_offset_tmp = Temporary.init(Type.usize, field_offset_id); + const result = try self.buildBinary(.i_sub, field_ptr_tmp, field_offset_tmp); + break :base_ptr_int try result.materialize(self); }; const base_ptr = self.spv.allocId(); @@ -5400,13 +5957,17 @@ const DeclGen = struct { else loaded_id; - const payload_ty_id = try self.resolveType(ptr_ty, .direct); - const null_id = try self.spv.constNull(payload_ty_id); + const ptr_ty_id = try self.resolveType(ptr_ty, .direct); + const null_id = try self.spv.constNull(ptr_ty_id); + const null_tmp = Temporary.init(ptr_ty, null_id); + const ptr = Temporary.init(ptr_ty, ptr_id); + const op: std.math.CompareOperator = switch (pred) { .is_null => .eq, .is_non_null => .neq, }; - return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id); + const result = try self.cmp(op, ptr, null_tmp); + return try result.materialize(self); } const is_non_null_id = blk: { diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 88fe677345..d1b2171786 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -155,6 +155,9 @@ cache: struct { void_type: ?IdRef = null, int_types: std.AutoHashMapUnmanaged(std.builtin.Type.Int, IdRef) = .{}, float_types: std.AutoHashMapUnmanaged(std.builtin.Type.Float, IdRef) = .{}, + // This cache is required so that @Vector(X, u1) in direct representation has the + // same ID as @Vector(X, bool) in indirect representation. + vector_types: std.AutoHashMapUnmanaged(struct { IdRef, u32 }, IdRef) = .{}, } = .{}, /// Set of Decls, referred to by Decl.Index. 
@@ -194,6 +197,7 @@ pub fn deinit(self: *Module) void { self.cache.int_types.deinit(self.gpa); self.cache.float_types.deinit(self.gpa); + self.cache.vector_types.deinit(self.gpa); self.decls.deinit(self.gpa); self.decl_deps.deinit(self.gpa); @@ -474,13 +478,17 @@ pub fn floatType(self: *Module, bits: u16) !IdRef { } pub fn vectorType(self: *Module, len: u32, child_id: IdRef) !IdRef { - const result_id = self.allocId(); - try self.sections.types_globals_constants.emit(self.gpa, .OpTypeVector, .{ - .id_result = result_id, - .component_type = child_id, - .component_count = len, - }); - return result_id; + const entry = try self.cache.vector_types.getOrPut(self.gpa, .{ child_id, len }); + if (!entry.found_existing) { + const result_id = self.allocId(); + entry.value_ptr.* = result_id; + try self.sections.types_globals_constants.emit(self.gpa, .OpTypeVector, .{ + .id_result = result_id, + .component_type = child_id, + .component_count = len, + }); + } + return entry.value_ptr.*; } pub fn constUndef(self: *Module, ty_id: IdRef) !IdRef { diff --git a/test/behavior/abs.zig b/test/behavior/abs.zig index 8ca160faff..21f02b2a3d 100644 --- a/test/behavior/abs.zig +++ b/test/behavior/abs.zig @@ -152,7 +152,6 @@ test "@abs int vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsIntVectors(1); try testAbsIntVectors(1); diff --git a/test/behavior/floatop.zig b/test/behavior/floatop.zig index 2e18b58d3c..d32319c644 100644 --- a/test/behavior/floatop.zig +++ b/test/behavior/floatop.zig @@ -275,7 +275,6 @@ test "@sqrt f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -287,7 +286,6 @@ test "@sqrt f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -389,7 +387,6 @@ test "@sqrt with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; try testSqrtWithVectors(); @@ -410,7 +407,6 @@ test "@sin f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return 
error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -422,7 +418,6 @@ test "@sin f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -464,7 +459,6 @@ test "@sin with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -486,7 +480,6 @@ test "@cos f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -498,7 +491,6 @@ test "@cos f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -540,7 +532,6 @@ test "@cos with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -574,7 +565,6 @@ test "@tan f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -616,7 +606,6 @@ test "@tan with vectors" { if (builtin.zig_backend == .stage2_wasm) 
return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -638,7 +627,6 @@ test "@exp f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -650,7 +638,6 @@ test "@exp f32/f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -696,7 +683,6 @@ test "@exp with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -718,7 +704,6 @@ test "@exp2 f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -730,7 +715,6 @@ test "@exp2 f32/f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -771,7 +755,6 @@ test "@exp2 with @vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -793,7 +776,6 @@ test "@log f16" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -805,7 +787,6 @@ test "@log f32/f64" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -847,7 +828,6 @@ test "@log with @vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -866,7 +846,6 @@ test "@log2 f16" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -878,7 +857,6 @@ test "@log2 f32/f64" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -919,7 +897,6 @@ test "@log2 with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     // https://github.com/ziglang/zig/issues/13681
     if (builtin.zig_backend == .stage2_llvm and
@@ -945,7 +922,6 @@ test "@log10 f16" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -957,7 +933,6 @@ test "@log10 f32/f64" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -998,7 +973,6 @@ test "@log10 with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -1243,7 +1217,6 @@ test "@ceil f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     try testCeil(f16);
@@ -1255,7 +1228,6 @@ test "@ceil f32/f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     try testCeil(f32);
@@ -1320,7 +1292,6 @@ test "@ceil with vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;
@@ -1344,7 +1315,6 @@ test "@trunc f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch.isMIPS()) {
@@ -1361,7 +1331,6 @@ test "@trunc f32/f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch.isMIPS()) {
@@ -1430,7 +1399,6 @@ fn testTrunc(comptime T: type) !void {
 test "@trunc with vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
@@ -1454,7 +1422,6 @@ test "neg f16" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
@@ -1472,7 +1439,6 @@ test "neg f32/f64" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

diff --git a/test/behavior/math.zig b/test/behavior/math.zig
index 59937515ab..66f86ede89 100644
--- a/test/behavior/math.zig
+++ b/test/behavior/math.zig
@@ -440,7 +440,6 @@ test "division" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
@@ -530,7 +529,6 @@ test "division half-precision floats" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -622,7 +620,6 @@ test "negation wrapping" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;

     try expectEqual(@as(u1, 1), negateWrap(u1, 1));
 }
@@ -1031,6 +1028,60 @@ test "@mulWithOverflow bitsize > 32" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO

+    {
+        var a: u40 = 3;
+        var b: u40 = 0x55_5555_5555;
+        var ov = @mulWithOverflow(a, b);
+
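+        // 3 * 0x55_5555_5555 == 0xff_ffff_ffff == maxInt(u40),
+        // the largest u40 product that does not overflow.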
+        try expect(ov[0] == 0xff_ffff_ffff);
+        try expect(ov[1] == 0);
+
+        // Check that overflow bits in the low word of wide multiplications are checked too.
+        // The intermediate result is less than 2**64.
+        b = 0x55_5555_5556;
+        ov = @mulWithOverflow(a, b);
+        try expect(ov[0] == 2);
+        try expect(ov[1] == 1);
+
+        // Check that overflow bits in the high word of wide multiplications are checked too.
+        // The intermediate result is more than 2**64 and bits 40..64 are not set.
+        a = 0x10_0000_0000;
+        b = 0x10_0000_0000;
+        ov = @mulWithOverflow(a, b);
+        try expect(ov[0] == 0);
+        try expect(ov[1] == 1);
+    }
+
+    {
+        var a: i40 = 3;
+        var b: i40 = -0x2a_aaaa_aaaa;
+        var ov = @mulWithOverflow(a, b);
+
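+        // 3 * -0x2a_aaaa_aaaa == -0x7f_ffff_fffe, which still fits:
+        // minInt(i40) == -0x80_0000_0000.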
+        try expect(ov[0] == -0x7f_ffff_fffe);
+        try expect(ov[1] == 0);
+
+        // Check that the sign bit is properly checked.
+        b = -0x2a_aaaa_aaab;
+        ov = @mulWithOverflow(a, b);
+        try expect(ov[0] == 0x7f_ffff_ffff);
+        try expect(ov[1] == 1);
+
+        // Check that the low-order bits above the sign are checked.
+        a = 6;
+        ov = @mulWithOverflow(a, b);
+        try expect(ov[0] == -2);
+        try expect(ov[1] == 1);
+
+        // Check that overflow bits in the high word of wide multiplications are checked too.
+        // The high parts and the sign of the low-order bits are all 1.
+        a = 0x08_0000_0000;
+        b = -0x08_0000_0001;
+        ov = @mulWithOverflow(a, b);
+
+        try expect(ov[0] == -0x8_0000_0000);
+        try expect(ov[1] == 1);
+    }
+
     {
         var a: u62 = 3;
         _ = &a;
@@ -1580,7 +1631,6 @@ test "@round f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

@@ -1592,7 +1642,6 @@ test "@round f32/f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

diff --git a/test/behavior/select.zig b/test/behavior/select.zig
index 90166dcfe5..f2a6cf8a63 100644
--- a/test/behavior/select.zig
+++ b/test/behavior/select.zig
@@ -8,7 +8,6 @@ test "@select vectors" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     try comptime selectVectors();
@@ -39,7 +38,6 @@ test "@select arrays" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) return error.SkipZigTest;
diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig
index 8987e0c091..2e860e1001 100644
--- a/test/behavior/vector.zig
+++ b/test/behavior/vector.zig
@@ -548,7 +548,6 @@ test "vector division operators" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_llvm and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

     const S = struct {