spirv: optional comparison

This commit is contained in:
Robin Voetter 2023-10-12 22:09:34 +02:00
parent 10b8171466
commit f4064d98e2
No known key found for this signature in database

View File

@ -52,12 +52,6 @@ const Block = struct {
const BlockMap = std.AutoHashMapUnmanaged(Air.Inst.Index, *Block); const BlockMap = std.AutoHashMapUnmanaged(Air.Inst.Index, *Block);
/// Maps Zig decl indices to SPIR-V linking information.
pub const DeclLinkMap = std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index);
/// Maps anon decl indices to SPIR-V linking information.
pub const AnonDeclLinkMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index);
/// This structure holds information that is relevant to the entire compilation, /// This structure holds information that is relevant to the entire compilation,
/// in contrast to `DeclGen`, which only holds relevant information about a /// in contrast to `DeclGen`, which only holds relevant information about a
/// single decl. /// single decl.
@ -70,10 +64,10 @@ pub const Object = struct {
/// The Zig module that this object file is generated for. /// The Zig module that this object file is generated for.
/// A map of Zig decl indices to SPIR-V decl indices. /// A map of Zig decl indices to SPIR-V decl indices.
decl_link: DeclLinkMap = .{}, decl_link: std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index) = .{},
/// A map of Zig InternPool indices for anonymous decls to SPIR-V decl indices. /// A map of Zig InternPool indices for anonymous decls to SPIR-V decl indices.
anon_decl_link: AnonDeclLinkMap = .{}, anon_decl_link: std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index) = .{},
/// A map that maps AIR intern pool indices to SPIR-V cache references (which /// A map that maps AIR intern pool indices to SPIR-V cache references (which
/// is basically the same thing except for SPIR-V). /// is basically the same thing except for SPIR-V).
@ -1266,22 +1260,32 @@ const DeclGen = struct {
const elem_ty = ty.childType(mod); const elem_ty = ty.childType(mod);
const elem_ty_ref = try self.resolveType(elem_ty, .indirect); const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse { var total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse {
return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)}); return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)});
}; };
if (!ty.hasRuntimeBitsIgnoreComptime(mod)) { const ty_ref = if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) blk: {
// The size of the array would be 0, but that is not allowed in SPIR-V.
// This path can be reached when the backend is asked to generate a pointer to
// an array of some zero-bit type. This should always be an indirect path.
assert(repr == .indirect);
// We cannot use the child type here, so just use an opaque type.
break :blk try self.spv.resolve(.{ .opaque_type = .{
.name = try self.spv.resolveString("zero-sized array"),
} });
} else if (total_len == 0) blk: {
// The size of the array would be 0, but that is not allowed in SPIR-V. // The size of the array would be 0, but that is not allowed in SPIR-V.
// This path can be reached for example when there is a slicing of a pointer // This path can be reached for example when there is a slicing of a pointer
// that produces a zero-length array. In all cases where this type can be generated, // that produces a zero-length array. In all cases where this type can be generated,
// we should be in an indirect path (direct uses of this type should be filtered out in Sema). // this should be an indirect path.
assert(repr == .indirect); assert(repr == .indirect);
return try self.spv.resolve(.{ .opaque_type = .{ // In this case, we have an array of a non-zero sized type. In this case,
.name = try self.spv.resolveString("zero-sized array"), // generate an array of 1 element instead, so that ptr_elem_ptr instructions
} }); // can be lowered to ptrAccessChain instead of manually performing the math.
} break :blk try self.spv.arrayType(1, elem_ty_ref);
} else try self.spv.arrayType(total_len, elem_ty_ref);
const ty_ref = try self.spv.arrayType(total_len, elem_ty_ref);
try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref }); try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
return ty_ref; return ty_ref;
}, },
@ -2554,38 +2558,85 @@ const DeclGen = struct {
var cmp_lhs_id = lhs_id; var cmp_lhs_id = lhs_id;
var cmp_rhs_id = rhs_id; var cmp_rhs_id = rhs_id;
const bool_ty_ref = try self.resolveType(Type.bool, .direct); const bool_ty_ref = try self.resolveType(Type.bool, .direct);
const op_ty = switch (ty.zigTypeTag(mod)) {
.Int, .Bool, .Float => ty,
.Enum => ty.intTagType(mod),
.ErrorSet => Type.u16,
.Pointer => blk: {
// Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are
// currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using
// OpConvertPtrToU...
cmp_lhs_id = self.spv.allocId();
cmp_rhs_id = self.spv.allocId();
const usize_ty_id = self.typeId(try self.sizeType());
try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
.id_result_type = usize_ty_id,
.id_result = cmp_lhs_id,
.pointer = lhs_id,
});
try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
.id_result_type = usize_ty_id,
.id_result = cmp_rhs_id,
.pointer = rhs_id,
});
break :blk Type.usize;
},
.Optional => {
const payload_ty = ty.optionalChild(mod);
if (ty.optionalReprIsPayload(mod)) {
assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod));
assert(!payload_ty.isSlice(mod));
return self.cmp(op, payload_ty, lhs_id, rhs_id);
}
const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
try self.extractField(Type.bool, lhs_id, 1)
else
try self.convertToDirect(Type.bool, lhs_id);
const rhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
try self.extractField(Type.bool, rhs_id, 1)
else
try self.convertToDirect(Type.bool, rhs_id);
const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id);
if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
return valid_cmp_id;
}
// TODO: Should we short circuit here? It shouldn't affect correctness, but
// perhaps it will generate more efficient code.
const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id);
// op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl
// op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
const result_id = self.spv.allocId();
const args = .{
.id_result_type = self.typeId(bool_ty_ref),
.id_result = result_id,
.operand_1 = valid_cmp_id,
.operand_2 = pl_cmp_id,
};
switch (op) {
.eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args),
.neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args),
else => unreachable,
}
return result_id;
},
else => unreachable,
};
const opcode: Opcode = opcode: { const opcode: Opcode = opcode: {
const op_ty = switch (ty.zigTypeTag(mod)) {
.Int, .Bool, .Float => ty,
.Enum => ty.intTagType(mod),
.ErrorSet => Type.u16,
.Pointer => blk: {
// Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are
// currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using
// OpConvertPtrToU...
cmp_lhs_id = self.spv.allocId();
cmp_rhs_id = self.spv.allocId();
const usize_ty_id = self.typeId(try self.sizeType());
try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
.id_result_type = usize_ty_id,
.id_result = cmp_lhs_id,
.pointer = lhs_id,
});
try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
.id_result_type = usize_ty_id,
.id_result = cmp_rhs_id,
.pointer = rhs_id,
});
break :blk Type.usize;
},
.Optional => unreachable, // TODO
else => unreachable,
};
const info = try self.arithmeticTypeInfo(op_ty); const info = try self.arithmeticTypeInfo(op_ty);
const signedness = switch (info.class) { const signedness = switch (info.class) {
.composite_integer => { .composite_integer => {
@ -2653,7 +2704,6 @@ const DeclGen = struct {
const lhs_id = try self.resolve(bin_op.lhs); const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs); const rhs_id = try self.resolve(bin_op.rhs);
const ty = self.typeOf(bin_op.lhs); const ty = self.typeOf(bin_op.lhs);
assert(ty.eql(self.typeOf(bin_op.rhs), self.module));
return try self.cmp(op, ty, lhs_id, rhs_id); return try self.cmp(op, ty, lhs_id, rhs_id);
} }
@ -3061,16 +3111,17 @@ const DeclGen = struct {
const mod = self.module; const mod = self.module;
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data; const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data;
const ptr_ty = self.typeOf(bin_op.lhs); const src_ptr_ty = self.typeOf(bin_op.lhs);
const elem_ty = ptr_ty.childType(mod); const elem_ty = src_ptr_ty.childType(mod);
const ptr_id = try self.resolve(bin_op.lhs);
if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) { if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) {
const ptr_ty_ref = try self.resolveType(ptr_ty, .direct); const dst_ptr_ty = self.typeOfIndex(inst);
return try self.spv.constUndef(ptr_ty_ref); return try self.bitCast(dst_ptr_ty, src_ptr_ty, ptr_id);
} }
const ptr_id = try self.resolve(bin_op.lhs);
const index_id = try self.resolve(bin_op.rhs); const index_id = try self.resolve(bin_op.rhs);
return try self.ptrElemPtr(ptr_ty, ptr_id, index_id); return try self.ptrElemPtr(src_ptr_ty, ptr_id, index_id);
} }
fn airArrayElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { fn airArrayElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {