diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index bb2fe4219c..c9c11dd44e 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -52,12 +52,6 @@ const Block = struct { const BlockMap = std.AutoHashMapUnmanaged(Air.Inst.Index, *Block); -/// Maps Zig decl indices to SPIR-V linking information. -pub const DeclLinkMap = std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index); - -/// Maps anon decl indices to SPIR-V linking information. -pub const AnonDeclLinkMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index); - /// This structure holds information that is relevant to the entire compilation, /// in contrast to `DeclGen`, which only holds relevant information about a /// single decl. @@ -70,10 +64,10 @@ pub const Object = struct { /// The Zig module that this object file is generated for. /// A map of Zig decl indices to SPIR-V decl indices. - decl_link: DeclLinkMap = .{}, + decl_link: std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index) = .{}, /// A map of Zig InternPool indices for anonymous decls to SPIR-V decl indices. - anon_decl_link: AnonDeclLinkMap = .{}, + anon_decl_link: std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index) = .{}, /// A map that maps AIR intern pool indices to SPIR-V cache references (which /// is basically the same thing except for SPIR-V). @@ -1266,22 +1260,32 @@ const DeclGen = struct { const elem_ty = ty.childType(mod); const elem_ty_ref = try self.resolveType(elem_ty, .indirect); - const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse { + var total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse { return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)}); }; - if (!ty.hasRuntimeBitsIgnoreComptime(mod)) { + const ty_ref = if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) blk: { + // The size of the array would be 0, but that is not allowed in SPIR-V. + // This path can be reached when the backend is asked to generate a pointer to + // an array of some zero-bit type. This should always be an indirect path. + assert(repr == .indirect); + + // We cannot use the child type here, so just use an opaque type. + break :blk try self.spv.resolve(.{ .opaque_type = .{ + .name = try self.spv.resolveString("zero-sized array"), + } }); + } else if (total_len == 0) blk: { // The size of the array would be 0, but that is not allowed in SPIR-V. // This path can be reached for example when there is a slicing of a pointer // that produces a zero-length array. In all cases where this type can be generated, - // we should be in an indirect path (direct uses of this type should be filtered out in Sema). + // this should be an indirect path. assert(repr == .indirect); - return try self.spv.resolve(.{ .opaque_type = .{ - .name = try self.spv.resolveString("zero-sized array"), - } }); - } + // In this case, we have an array of a non-zero sized type. In this case, + // generate an array of 1 element instead, so that ptr_elem_ptr instructions + // can be lowered to ptrAccessChain instead of manually performing the math. + break :blk try self.spv.arrayType(1, elem_ty_ref); + } else try self.spv.arrayType(total_len, elem_ty_ref); - const ty_ref = try self.spv.arrayType(total_len, elem_ty_ref); try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref }); return ty_ref; }, @@ -2554,38 +2558,85 @@ const DeclGen = struct { var cmp_lhs_id = lhs_id; var cmp_rhs_id = rhs_id; const bool_ty_ref = try self.resolveType(Type.bool, .direct); + const op_ty = switch (ty.zigTypeTag(mod)) { + .Int, .Bool, .Float => ty, + .Enum => ty.intTagType(mod), + .ErrorSet => Type.u16, + .Pointer => blk: { + // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are + // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using + // OpConvertPtrToU... + cmp_lhs_id = self.spv.allocId(); + cmp_rhs_id = self.spv.allocId(); + + const usize_ty_id = self.typeId(try self.sizeType()); + + try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ + .id_result_type = usize_ty_id, + .id_result = cmp_lhs_id, + .pointer = lhs_id, + }); + + try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ + .id_result_type = usize_ty_id, + .id_result = cmp_rhs_id, + .pointer = rhs_id, + }); + + break :blk Type.usize; + }, + .Optional => { + const payload_ty = ty.optionalChild(mod); + if (ty.optionalReprIsPayload(mod)) { + assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod)); + assert(!payload_ty.isSlice(mod)); + return self.cmp(op, payload_ty, lhs_id, rhs_id); + } + + const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) + try self.extractField(Type.bool, lhs_id, 1) + else + try self.convertToDirect(Type.bool, lhs_id); + + const rhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) + try self.extractField(Type.bool, rhs_id, 1) + else + try self.convertToDirect(Type.bool, rhs_id); + + const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id); + if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) { + return valid_cmp_id; + } + + // TODO: Should we short circuit here? It shouldn't affect correctness, but + // perhaps it will generate more efficient code. + + const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0); + const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0); + + const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id); + + // op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl + // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl + + const result_id = self.spv.allocId(); + const args = .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = result_id, + .operand_1 = valid_cmp_id, + .operand_2 = pl_cmp_id, + }; + switch (op) { + .eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args), + .neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args), + else => unreachable, + } + return result_id; + }, + else => unreachable, + }; + const opcode: Opcode = opcode: { - const op_ty = switch (ty.zigTypeTag(mod)) { - .Int, .Bool, .Float => ty, - .Enum => ty.intTagType(mod), - .ErrorSet => Type.u16, - .Pointer => blk: { - // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are - // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using - // OpConvertPtrToU... - cmp_lhs_id = self.spv.allocId(); - cmp_rhs_id = self.spv.allocId(); - - const usize_ty_id = self.typeId(try self.sizeType()); - - try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ - .id_result_type = usize_ty_id, - .id_result = cmp_lhs_id, - .pointer = lhs_id, - }); - - try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ - .id_result_type = usize_ty_id, - .id_result = cmp_rhs_id, - .pointer = rhs_id, - }); - - break :blk Type.usize; - }, - .Optional => unreachable, // TODO - else => unreachable, - }; - const info = try self.arithmeticTypeInfo(op_ty); const signedness = switch (info.class) { .composite_integer => { @@ -2653,7 +2704,6 @@ const DeclGen = struct { const lhs_id = try self.resolve(bin_op.lhs); const rhs_id = try self.resolve(bin_op.rhs); const ty = self.typeOf(bin_op.lhs); - assert(ty.eql(self.typeOf(bin_op.rhs), self.module)); return try self.cmp(op, ty, lhs_id, rhs_id); } @@ -3061,16 +3111,17 @@ const DeclGen = struct { const mod = self.module; const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data; - const ptr_ty = self.typeOf(bin_op.lhs); - const elem_ty = ptr_ty.childType(mod); + const src_ptr_ty = self.typeOf(bin_op.lhs); + const elem_ty = src_ptr_ty.childType(mod); + const ptr_id = try self.resolve(bin_op.lhs); + if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) { - const ptr_ty_ref = try self.resolveType(ptr_ty, .direct); - return try self.spv.constUndef(ptr_ty_ref); + const dst_ptr_ty = self.typeOfIndex(inst); + return try self.bitCast(dst_ptr_ty, src_ptr_ty, ptr_id); } - const ptr_id = try self.resolve(bin_op.lhs); const index_id = try self.resolve(bin_op.rhs); - return try self.ptrElemPtr(ptr_ty, ptr_id, index_id); + return try self.ptrElemPtr(src_ptr_ty, ptr_id, index_id); } fn airArrayElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {