diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index ed04ee475b..0574b7ee9e 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -22,6 +22,7 @@ const IdResultType = spec.IdResultType; const StorageClass = spec.StorageClass; const SpvModule = @import("spirv/Module.zig"); +const IdRange = SpvModule.IdRange; const SpvSection = @import("spirv/Section.zig"); const SpvAssembler = @import("spirv/Assembler.zig"); @@ -32,7 +33,7 @@ pub const zig_call_abi_ver = 3; const InternMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, DeclGen.Repr }, IdResult); const PtrTypeMap = std.AutoHashMapUnmanaged( - struct { InternPool.Index, StorageClass }, + struct { InternPool.Index, StorageClass, DeclGen.Repr }, struct { ty_id: IdRef, fwd_emitted: bool }, ); @@ -626,7 +627,7 @@ const DeclGen = struct { } /// Checks whether the type can be directly translated to SPIR-V vectors - fn isVector(self: *DeclGen, ty: Type) bool { + fn isSpvVector(self: *DeclGen, ty: Type) bool { const mod = self.module; const target = self.getTarget(); if (ty.zigTypeTag(mod) != .Vector) return false; @@ -798,26 +799,39 @@ const DeclGen = struct { /// Construct a vector at runtime. /// ty must be an vector type. - /// Constituents should be in `indirect` representation (as the elements of an vector should be). - /// Result is in `direct` representation. fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef { - // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which' - // operands are not constant. - // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 - // For now, just initialize the struct by setting the fields manually... - // TODO: Make this OpCompositeConstruct when we can const mod = self.module; - const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function }); - const ptr_elem_ty_id = try self.ptrType(ty.elemType2(mod), .Function); - for (constituents, 0..) 
|constitent_id, index| { - const ptr_id = try self.accessChain(ptr_elem_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))}); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = constitent_id, - }); - } + assert(ty.vectorLen(mod) == constituents.len); - return try self.load(ty, ptr_composite_id, .{}); + // Note: older versions of the Khronos SPRIV-LLVM translator crash on this instruction + // because it cannot construct structs which' operands are not constant. + // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 + // Currently this is the case for Intel OpenCL CPU runtime (2023-WW46), but the + // alternatives dont work properly: + // - using temporaries/pointers doesn't work properly with vectors of bool, causes + // backends that use llvm to crash + // - using OpVectorInsertDynamic doesn't work for non-spirv-vectors of bool. + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{ + .id_result_type = try self.resolveType(ty, .direct), + .id_result = result_id, + .constituents = constituents, + }); + return result_id; + } + + /// Construct a vector at runtime with all lanes set to the same value. + /// ty must be an vector type. + fn constructVectorSplat(self: *DeclGen, ty: Type, constituent: IdRef) !IdRef { + const mod = self.module; + const n = ty.vectorLen(mod); + + const constituents = try self.gpa.alloc(IdRef, n); + defer self.gpa.free(constituents); + @memset(constituents, constituent); + + return try self.constructVector(ty, constituents); } /// Construct an array at runtime. 
@@ -1031,21 +1045,27 @@ const DeclGen = struct { const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod))); defer self.gpa.free(constituents); + const child_repr: Repr = switch (tag) { + .array_type => .indirect, + .vector_type => .direct, + else => unreachable, + }; + switch (aggregate.storage) { .bytes => |bytes| { // TODO: This is really space inefficient, perhaps there is a better // way to do it? for (constituents, bytes.toSlice(constituents.len, ip)) |*constituent, byte| { - constituent.* = try self.constInt(elem_ty, byte, .indirect); + constituent.* = try self.constInt(elem_ty, byte, child_repr); } }, .elems => |elems| { for (constituents, elems) |*constituent, elem| { - constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), .indirect); + constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), child_repr); } }, .repeated_elem => |elem| { - @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), .indirect)); + @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), child_repr)); }, } @@ -1334,7 +1354,11 @@ const DeclGen = struct { } fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !IdRef { - const key = .{ child_ty.toIntern(), storage_class }; + return try self.ptrType2(child_ty, storage_class, .indirect); + } + + fn ptrType2(self: *DeclGen, child_ty: Type, storage_class: StorageClass, child_repr: Repr) !IdRef { + const key = .{ child_ty.toIntern(), storage_class, child_repr }; const entry = try self.ptr_types.getOrPut(self.gpa, key); if (entry.found_existing) { const fwd_id = entry.value_ptr.ty_id; @@ -1354,7 +1378,7 @@ const DeclGen = struct { .fwd_emitted = false, }; - const child_ty_id = try self.resolveType(child_ty, .indirect); + const child_ty_id = try self.resolveType(child_ty, child_repr); try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{ .id_result = result_id, @@ -1645,11 +1669,10 @@ const 
DeclGen = struct { }, .Vector => { const elem_ty = ty.childType(mod); - // TODO: Make `.direct`. - const elem_ty_id = try self.resolveType(elem_ty, .indirect); + const elem_ty_id = try self.resolveType(elem_ty, repr); const len = ty.vectorLen(mod); - if (self.isVector(ty)) { + if (self.isSpvVector(ty)) { return try self.spv.vectorType(len, elem_ty_id); } else { return try self.arrayType(len, elem_ty_id); @@ -1948,7 +1971,7 @@ const DeclGen = struct { const mod = wip.dg.module; if (wip.is_array) { assert(ty.isVector(mod)); - return try wip.dg.extractField(ty.childType(mod), value, @intCast(index)); + return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index)); } else { assert(index == 0); return value; @@ -1961,11 +1984,7 @@ const DeclGen = struct { /// Results is in `direct` representation. fn finalize(wip: *WipElementWise) !IdRef { if (wip.is_array) { - // Convert all the constituents to indirect, as required for the array. - for (wip.results) |*result| { - result.* = try wip.dg.convertToIndirect(wip.ty, result.*); - } - return try wip.dg.constructArray(wip.result_ty, wip.results); + return try wip.dg.constructVector(wip.result_ty, wip.results); } else { return wip.results[0]; } @@ -1982,7 +2001,7 @@ const DeclGen = struct { /// Create a new element-wise operation. fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise { const mod = self.module; - const is_array = result_ty.isVector(mod) and (!self.isVector(result_ty) or force_element_wise); + const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise); const num_results = if (is_array) result_ty.vectorLen(mod) else 1; const results = try self.gpa.alloc(IdRef, num_results); @memset(results, undefined); @@ -2253,29 +2272,102 @@ const DeclGen = struct { /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct). 
fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
         const mod = self.module;
-        return switch (ty.zigTypeTag(mod)) {
-            .Bool => blk: {
-                const result_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
-                    .id_result_type = try self.resolveType(Type.bool, .direct),
-                    .id_result = result_id,
-                    .operand_1 = operand_id,
-                    .operand_2 = try self.constBool(false, .indirect),
-                });
-                break :blk result_id;
+        const scalar_ty = ty.scalarType(mod);
+        const is_spv_vector = self.isSpvVector(ty);
+        switch (scalar_ty.zigTypeTag(mod)) {
+            .Bool => {
+                // TODO: We may want to use something like elementWise in this function.
+                // First we need to audit whether this would recursively call into itself.
+
+                // Scalars and SPIR-V vectors are converted with a single OpINotEqual against
+                // the indirect `false` constant (splatted across all lanes for vectors).
+                if (!ty.isVector(mod) or is_spv_vector) {
+                    const result_id = self.spv.allocId();
+                    const scalar_false_id = try self.constBool(false, .indirect);
+                    const false_id = if (is_spv_vector) blk: {
+                        const index = try mod.intern_pool.get(mod.gpa, .{
+                            .vector_type = .{
+                                .len = ty.vectorLen(mod),
+                                .child = Type.u1.toIntern(),
+                            },
+                        });
+                        const vec_ty = Type.fromInterned(index);
+                        break :blk try self.constructVectorSplat(vec_ty, scalar_false_id);
+                    } else scalar_false_id;
+
+                    try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                        .id_result_type = try self.resolveType(ty, .direct),
+                        .id_result = result_id,
+                        .operand_1 = operand_id,
+                        .operand_2 = false_id,
+                    });
+                    return result_id;
+                }
+
+                // Bool vectors that are not SPIR-V vectors are converted one element at a time.
+                // The id buffer is scratch space only: release it on every exit path, the same
+                // way constructVectorSplat releases the buffer it hands to constructVector.
+                const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+                defer self.gpa.free(constituents);
+                for (constituents, 0..) |*id, i| {
+                    const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+                    id.* = try self.convertToDirect(scalar_ty, element);
+                }
+                return try self.constructVector(ty, constituents);
             },
-            else => operand_id,
-        };
+            else => return operand_id,
+        }
     }
 
     /// Convert representation from direct (in 'register) to direct (in memory)
     /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect).
fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
         const mod = self.module;
-        return switch (ty.zigTypeTag(mod)) {
-            .Bool => try self.intFromBool(Type.u1, operand_id),
-            else => operand_id,
-        };
+        const scalar_ty = ty.scalarType(mod);
+        const is_spv_vector = self.isSpvVector(ty);
+        switch (scalar_ty.zigTypeTag(mod)) {
+            .Bool => {
+                // Every vector input needs a vector-of-u1 result type, not only SPIR-V vectors:
+                // the element-wise path below passes it to constructVector, which requires a
+                // vector type and asserts on its length.
+                const result_ty = if (ty.isVector(mod)) blk: {
+                    const index = try mod.intern_pool.get(mod.gpa, .{
+                        .vector_type = .{
+                            .len = ty.vectorLen(mod),
+                            .child = Type.u1.toIntern(),
+                        },
+                    });
+                    break :blk Type.fromInterned(index);
+                } else Type.u1;
+
+                if (!ty.isVector(mod) or is_spv_vector) {
+                    // TODO: We may want to use something like elementWise in this function.
+                    // First we need to audit whether this would recursively call into itself.
+                    // Also unify it with intFromBool
+
+                    const scalar_zero_id = try self.constInt(Type.u1, 0, .direct);
+                    const scalar_one_id = try self.constInt(Type.u1, 1, .direct);
+
+                    const zero_id = if (is_spv_vector)
+                        try self.constructVectorSplat(result_ty, scalar_zero_id)
+                    else
+                        scalar_zero_id;
+
+                    const one_id = if (is_spv_vector)
+                        try self.constructVectorSplat(result_ty, scalar_one_id)
+                    else
+                        scalar_one_id;
+
+                    // Select 1 where the condition holds and 0 elsewhere.
+                    const result_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                        .id_result_type = try self.resolveType(result_ty, .direct),
+                        .id_result = result_id,
+                        .condition = operand_id,
+                        .object_1 = one_id,
+                        .object_2 = zero_id,
+                    });
+                    return result_id;
+                }
+
+                // Bool vectors that are not SPIR-V vectors are converted one element at a time.
+                // The id buffer is scratch space only: release it on every exit path, the same
+                // way constructVectorSplat releases the buffer it hands to constructVector.
+                const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+                defer self.gpa.free(constituents);
+                for (constituents, 0..) |*id, i| {
+                    const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+                    id.* = try self.convertToIndirect(scalar_ty, element);
+                }
+                return try self.constructVector(result_ty, constituents);
+            },
+            else => return operand_id,
+        }
     }
 
     fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef {
@@ -2292,6 +2384,21 @@ const DeclGen = struct {
         return try self.convertToDirect(result_ty, result_id);
     }
 
+    fn extractVectorComponent(self: *DeclGen, result_ty: Type, vector_id: IdRef, field: u32) !IdRef {
+        // Whether this is an OpTypeVector or OpTypeArray, we need to emit the same instruction regardless.
+        const result_ty_id = try self.resolveType(result_ty, .direct);
+        const result_id = self.spv.allocId();
+        const indexes = [_]u32{field};
+        try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
+            .id_result_type = result_ty_id,
+            .id_result = result_id,
+            .composite = vector_id,
+            .indexes = &indexes,
+        });
+        // Vector components are already stored in direct representation.
+ return result_id; + } + const MemoryOptions = struct { is_volatile: bool = false, }; @@ -2926,7 +3033,7 @@ const DeclGen = struct { const ov_ty = result_ty.structFieldType(1, self.module); const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isVector(operand_ty)) + const cmp_ty_id = if (self.isSpvVector(operand_ty)) // TODO: Resolving a vector type with .direct should return a SPIR-V vector try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) else @@ -3100,7 +3207,7 @@ const DeclGen = struct { const ov_ty = result_ty.structFieldType(1, self.module); const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isVector(operand_ty)) + const cmp_ty_id = if (self.isSpvVector(operand_ty)) // TODO: Resolving a vector type with .direct should return a SPIR-V vector try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) else @@ -3312,7 +3419,7 @@ const DeclGen = struct { const info = self.arithmeticTypeInfo(operand_ty); - var result_id = try self.extractField(scalar_ty, operand, 0); + var result_id = try self.extractVectorComponent(scalar_ty, operand, 0); const len = operand_ty.vectorLen(mod); switch (reduce.operation) { @@ -3320,7 +3427,7 @@ const DeclGen = struct { const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt; for (1..len) |i| { const lhs = result_id; - const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs); } @@ -3354,7 +3461,7 @@ const DeclGen = struct { for (1..len) |i| { const lhs = result_id; - const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); result_id = self.spv.allocId(); try self.func.body.emitRaw(self.spv.gpa, opcode, 4); @@ -3388,9 +3495,9 @@ const 
DeclGen = struct { const index = elem.toSignedInt(mod); if (index >= 0) { - result_id.* = try self.extractField(wip.ty, a, @intCast(index)); + result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index)); } else { - result_id.* = try self.extractField(wip.ty, b, @intCast(~index)); + result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index)); } } return try wip.finalize(); @@ -4086,8 +4193,7 @@ const DeclGen = struct { defer self.gpa.free(elem_ids); for (elements, 0..) |element, i| { - const id = try self.resolve(element); - elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id); + elem_ids[i] = try self.resolve(element); } return try self.constructVector(result_ty, elem_ids); @@ -4234,16 +4340,54 @@ const DeclGen = struct { const array_id = try self.resolve(bin_op.lhs); const index_id = try self.resolve(bin_op.rhs); + if (self.isSpvVector(array_ty)) { + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpVectorExtractDynamic, .{ + .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result = result_id, + .vector = array_id, + .index = index_id, + }); + return result_id; + } + // SPIR-V doesn't have an array indexing function for some damn reason. // For now, just generate a temporary and use that. // TODO: This backend probably also should use isByRef from llvm... 
- const elem_ptr_ty_id = try self.ptrType(elem_ty, .Function); + const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct); + const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct); - const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function }); - try self.store(array_ty, tmp_id, array_id, .{}); - const elem_ptr_id = try self.accessChainId(elem_ptr_ty_id, tmp_id, &.{index_id}); - return try self.load(elem_ty, elem_ptr_id, .{}); + const tmp_id = self.spv.allocId(); + try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ + .id_result_type = ptr_array_ty_id, + .id_result = tmp_id, + .storage_class = .Function, + }); + + try self.func.body.emit(self.spv.gpa, .OpStore, .{ + .pointer = tmp_id, + .object = array_id, + }); + + const elem_ptr_id = try self.accessChainId(ptr_elem_ty_id, tmp_id, &.{index_id}); + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpLoad, .{ + .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result = result_id, + .pointer = elem_ptr_id, + }); + + if (array_ty.isVector(mod)) { + // Result is already in direct representation + return result_id; + } + + // This is an array type; the elements are stored in indirect representation. + // We have to convert the type to direct. + + return try self.convertToDirect(elem_ty, result_id); } fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {