spirv: lower air is_null, is_non_null

Implements AIR lowering for is_null and is_non_null tags.

Additionally this cleans up and centralizes the logic to convert from 'direct'
representation to 'indirect' representation and vice-versa. The related functions,
as well as the functions that use it, are all moved near eachother so that the
conversion logic remains in a central place. Extracting/inserting fields
and loading/storing pointers should go through these functions.
This commit is contained in:
Robin Voetter 2023-04-11 22:13:54 +02:00
parent 435a5660ce
commit 83ab1ba8fd
No known key found for this signature in database
GPG Key ID: E755662F227CB468

View File

@ -402,9 +402,21 @@ pub const DeclGen = struct {
return result_id;
}
fn constUndef(self: *DeclGen, ty_ref: SpvType.Ref) Error!IdRef {
fn constUndef(self: *DeclGen, ty_ref: SpvType.Ref) !IdRef {
const result_id = self.spv.allocId();
try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id });
try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{
.id_result_type = self.typeId(ty_ref),
.id_result = result_id,
});
return result_id;
}
fn constNull(self: *DeclGen, ty_ref: SpvType.Ref) !IdRef {
const result_id = self.spv.allocId();
try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpConstantNull, .{
.id_result_type = self.typeId(ty_ref),
.id_result = result_id,
});
return result_id;
}
@ -674,7 +686,7 @@ pub const DeclGen = struct {
try self.addConstBool(has_payload);
return;
} else if (ty.optionalReprIsPayload()) {
// Optional representation is a nullable pointer.
// Optional representation is a nullable pointer or slice.
if (val.castTag(.opt_payload)) |payload| {
try self.lower(payload_ty, payload.data);
} else if (has_payload) {
@ -1257,7 +1269,7 @@ pub const DeclGen = struct {
const payload_ty_ref = try self.resolveType(payload_ty, .indirect);
if (ty.optionalReprIsPayload()) {
// Optional is actually a pointer.
// Optional is actually a pointer or a slice.
return payload_ty_ref;
}
@ -1523,6 +1535,93 @@ pub const DeclGen = struct {
}
}
/// Convert representation from indirect (in memory) to direct (in 'register')
/// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct).
fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
// const direct_ty_ref = try self.resolveType(ty, .direct);
return switch (ty.zigTypeTag()) {
.Bool => blk: {
const direct_bool_ty_ref = try self.resolveType(ty, .direct);
const indirect_bool_ty_ref = try self.resolveType(ty, .indirect);
const zero_id = try self.constInt(indirect_bool_ty_ref, 0);
const result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
.id_result_type = self.typeId(direct_bool_ty_ref),
.id_result = result_id,
.operand_1 = operand_id,
.operand_2 = zero_id,
});
break :blk result_id;
},
else => operand_id,
};
}
/// Convert representation from direct (in 'register) to direct (in memory)
/// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect).
fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
return switch (ty.zigTypeTag()) {
.Bool => blk: {
const indirect_bool_ty_ref = try self.resolveType(ty, .indirect);
const zero_id = try self.constInt(indirect_bool_ty_ref, 0);
const one_id = try self.constInt(indirect_bool_ty_ref, 1);
const result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpSelect, .{
.id_result_type = self.typeId(indirect_bool_ty_ref),
.id_result = result_id,
.condition = operand_id,
.object_1 = one_id,
.object_2 = zero_id,
});
break :blk result_id;
},
else => operand_id,
};
}
fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef {
const result_ty_ref = try self.resolveType(result_ty, .indirect);
const result_id = self.spv.allocId();
const indexes = [_]u32{field};
try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
.id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.composite = object,
.indexes = &indexes,
});
// Convert bools; direct structs have their field types as indirect values.
return try self.convertToDirect(result_ty, result_id);
}
fn load(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef) !IdRef {
const value_ty = ptr_ty.childType();
const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect);
const result_id = self.spv.allocId();
const access = spec.MemoryAccess.Extended{
.Volatile = ptr_ty.isVolatilePtr(),
};
try self.func.body.emit(self.spv.gpa, .OpLoad, .{
.id_result_type = self.typeId(indirect_value_ty_ref),
.id_result = result_id,
.pointer = ptr_id,
.memory_access = access,
});
return try self.convertToDirect(value_ty, result_id);
}
fn store(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, value_id: IdRef) !void {
const value_ty = ptr_ty.childType();
const indirect_value_id = try self.convertToIndirect(value_ty, value_id);
const access = spec.MemoryAccess.Extended{
.Volatile = ptr_ty.isVolatilePtr(),
};
try self.func.body.emit(self.spv.gpa, .OpStore, .{
.pointer = ptr_id,
.object = indirect_value_id,
.memory_access = access,
});
}
fn genBody(self: *DeclGen, body: []const Air.Inst.Index) Error!void {
for (body) |inst| {
try self.genInst(inst);
@ -1615,6 +1714,9 @@ pub const DeclGen = struct {
.unwrap_errunion_err => try self.airErrUnionErr(inst),
.wrap_errunion_err => try self.airWrapErrUnionErr(inst),
.is_null => try self.airIsNull(inst, .is_null),
.is_non_null => try self.airIsNull(inst, .is_non_null),
.assembly => try self.airAssembly(inst),
.call => try self.airCall(inst, .auto),
@ -1776,18 +1878,17 @@ pub const DeclGen = struct {
.float, .bool => unreachable,
}
const operand_ty_id = try self.resolveTypeId(operand_ty);
const result_type_id = try self.resolveTypeId(result_ty);
const overflow_member_ty_ref = try self.intType(.unsigned, info.bits);
// The operand type must be the same as the result type in SPIR-V.
const operand_ty_ref = try self.resolveType(operand_ty, .direct);
const operand_ty_id = self.typeId(operand_ty_ref);
const op_result_id = blk: {
// Construct the SPIR-V result type.
// It is almost the same as the zig one, except that the fields must be the same type
// and they must be unsigned.
const overflow_result_ty_ref = try self.spv.simpleStructType(&.{
.{ .ty = overflow_member_ty_ref, .name = "res" },
.{ .ty = overflow_member_ty_ref, .name = "ov" },
.{ .ty = operand_ty_ref, .name = "res" },
.{ .ty = operand_ty_ref, .name = "ov" },
});
const result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpIAddCarry, .{
@ -1801,11 +1902,13 @@ pub const DeclGen = struct {
// Now convert the SPIR-V flavor result into a Zig-flavor result.
// First, extract the two fields.
const unsigned_result = try self.extractField(overflow_member_ty_ref, op_result_id, 0);
const overflow = try self.extractField(overflow_member_ty_ref, op_result_id, 1);
const unsigned_result = try self.extractField(operand_ty, op_result_id, 0);
const overflow = try self.extractField(operand_ty, op_result_id, 1);
// We need to convert the results to the types that Zig expects here.
// The `result` is the same type except unsigned, so we can just bitcast that.
// TODO: This can be removed in Kernels as there are only unsigned ints. Maybe for
// shaders as well?
const result = try self.bitcast(operand_ty_id, unsigned_result);
// The overflow needs to be converted into whatever is used to represent it in Zig.
@ -1828,7 +1931,7 @@ pub const DeclGen = struct {
const result_id = self.spv.allocId();
const constituents = [_]IdRef{ result, casted_overflow };
try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
.id_result_type = result_type_id,
.id_result_type = operand_ty_id,
.id_result = result_id,
.constituents = &constituents,
});
@ -1980,25 +2083,14 @@ pub const DeclGen = struct {
return result_id;
}
fn extractField(self: *DeclGen, result_ty_ref: SpvType.Ref, object: IdRef, field: u32) !IdRef {
const result_id = self.spv.allocId();
const indexes = [_]u32{field};
try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
.id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.composite = object,
.indexes = &indexes,
});
// TODO: Convert bools, direct structs should have their field types as indirect values.
return result_id;
}
fn airSliceField(self: *DeclGen, inst: Air.Inst.Index, field: u32) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const field_ty = self.air.typeOfIndex(inst);
const operand_id = try self.resolve(ty_op.operand);
return try self.extractField(
try self.resolveType(self.air.typeOfIndex(inst), .direct),
try self.resolve(ty_op.operand),
field_ty,
operand_id,
field,
);
}
@ -2367,35 +2459,6 @@ pub const DeclGen = struct {
return try self.load(ptr_ty, operand);
}
fn load(self: *DeclGen, ptr_ty: Type, ptr: IdRef) !IdRef {
const value_ty = ptr_ty.childType();
const direct_result_ty_ref = try self.resolveType(value_ty, .direct);
const indirect_result_ty_ref = try self.resolveType(value_ty, .indirect);
const result_id = self.spv.allocId();
const access = spec.MemoryAccess.Extended{
.Volatile = ptr_ty.isVolatilePtr(),
};
try self.func.body.emit(self.spv.gpa, .OpLoad, .{
.id_result_type = self.typeId(indirect_result_ty_ref),
.id_result = result_id,
.pointer = ptr,
.memory_access = access,
});
if (value_ty.zigTypeTag() == .Bool) {
// Convert indirect bool to direct bool
const zero_id = try self.constInt(indirect_result_ty_ref, 0);
const casted_result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
.id_result_type = self.typeId(direct_result_ty_ref),
.id_result = casted_result_id,
.operand_1 = result_id,
.operand_2 = zero_id,
});
return casted_result_id;
}
return result_id;
}
fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void {
const bin_op = self.air.instructions.items(.data)[inst].bin_op;
const ptr_ty = self.air.typeOf(bin_op.lhs);
@ -2405,35 +2468,6 @@ pub const DeclGen = struct {
try self.store(ptr_ty, ptr, value);
}
fn store(self: *DeclGen, ptr_ty: Type, ptr: IdRef, value: IdRef) !void {
const value_ty = ptr_ty.childType();
const converted_value = switch (value_ty.zigTypeTag()) {
.Bool => blk: {
const indirect_bool_ty_ref = try self.resolveType(value_ty, .indirect);
const result_id = self.spv.allocId();
const zero = try self.constInt(indirect_bool_ty_ref, 0);
const one = try self.constInt(indirect_bool_ty_ref, 1);
try self.func.body.emit(self.spv.gpa, .OpSelect, .{
.id_result_type = self.typeId(indirect_bool_ty_ref),
.id_result = result_id,
.condition = value,
.object_1 = one,
.object_2 = zero,
});
break :blk result_id;
},
else => value,
};
const access = spec.MemoryAccess.Extended{
.Volatile = ptr_ty.isVolatilePtr(),
};
try self.func.body.emit(self.spv.gpa, .OpStore, .{
.pointer = ptr,
.object = converted_value,
.memory_access = access,
});
}
fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void {
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const loop = self.air.extraData(Air.Block, ty_pl.payload);
@ -2488,14 +2522,13 @@ pub const DeclGen = struct {
const payload_ty = self.air.typeOfIndex(inst);
const err_ty_ref = try self.resolveType(Type.anyerror, .direct);
const payload_ty_ref = try self.resolveType(payload_ty, .direct);
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
const eu_layout = self.errorUnionLayout(payload_ty);
if (!err_union_ty.errorUnionSet().errorSetIsEmpty()) {
const err_id = if (eu_layout.payload_has_bits)
try self.extractField(err_ty_ref, err_union_id, eu_layout.errorFieldIndex())
try self.extractField(Type.anyerror, err_union_id, eu_layout.errorFieldIndex())
else
err_union_id;
@ -2535,7 +2568,7 @@ pub const DeclGen = struct {
return null;
}
return try self.extractField(payload_ty_ref, err_union_id, eu_layout.payloadFieldIndex());
return try self.extractField(payload_ty, err_union_id, eu_layout.payloadFieldIndex());
}
fn airErrUnionErr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@ -2559,7 +2592,7 @@ pub const DeclGen = struct {
return operand_id;
}
return try self.extractField(err_ty_ref, operand_id, eu_layout.errorFieldIndex());
return try self.extractField(Type.anyerror, operand_id, eu_layout.errorFieldIndex());
}
fn airWrapErrUnionErr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@ -2598,6 +2631,69 @@ pub const DeclGen = struct {
return result_id;
}
fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
const un_op = self.air.instructions.items(.data)[inst].un_op;
const operand_id = try self.resolve(un_op);
const optional_ty = self.air.typeOf(un_op);
var buf: Type.Payload.ElemType = undefined;
const payload_ty = optional_ty.optionalChild(&buf);
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
if (optional_ty.optionalReprIsPayload()) {
// Pointer payload represents nullability: pointer or slice.
var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined;
const ptr_ty = if (payload_ty.isSlice())
payload_ty.slicePtrFieldType(&ptr_buf)
else
payload_ty;
const ptr_id = if (payload_ty.isSlice())
try self.extractField(Type.bool, operand_id, 0)
else
operand_id;
const payload_ty_ref = try self.resolveType(ptr_ty, .direct);
const null_id = try self.constNull(payload_ty_ref);
const result_id = self.spv.allocId();
const operands = .{
.id_result_type = self.typeId(bool_ty_ref),
.id_result = result_id,
.operand_1 = ptr_id,
.operand_2 = null_id,
};
switch (pred) {
.is_null => try self.func.body.emit(self.spv.gpa, .OpPtrEqual, operands),
.is_non_null => try self.func.body.emit(self.spv.gpa, .OpPtrNotEqual, operands),
}
return result_id;
}
const is_non_null_id = if (optional_ty.hasRuntimeBitsIgnoreComptime())
try self.extractField(Type.bool, operand_id, 1)
else
// Optional representation is bool indicating whether the optional is set
operand_id;
return switch (pred) {
.is_null => blk: {
// Invert condition
const result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
.id_result_type = self.typeId(bool_ty_ref),
.id_result = result_id,
.operand = is_non_null_id,
});
break :blk result_id;
},
.is_non_null => is_non_null_id,
};
}
fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
const target = self.getTarget();
const pl_op = self.air.instructions.items(.data)[inst].pl_op;