spirv: lower air try

Implements code generation for the try air tag. This commit also adds
a utility `errorUnionLayout` function that helps keeping the layout
of a spir-v error union consistent.
This commit is contained in:
Robin Voetter 2023-04-10 20:34:15 +02:00
parent dfecf89d06
commit 0bae2caaf3
No known key found for this signature in database
GPG Key ID: E755662F227CB468

View File

@ -765,21 +765,18 @@ pub const DeclGen = struct {
const is_pl = val.errorUnionIsPayload();
const error_val = if (!is_pl) val else Value.initTag(.zero);
if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
const eu_layout = dg.errorUnionLayout(payload_ty);
if (!eu_layout.payload_has_bits) {
return try self.lower(Type.anyerror, error_val);
}
const payload_align = payload_ty.abiAlignment(target);
const error_align = Type.anyerror.abiAlignment(target);
const payload_size = payload_ty.abiSize(target);
const error_size = Type.anyerror.abiAlignment(target);
const ty_size = ty.abiSize(target);
const padding = ty_size - payload_size - error_size;
const payload_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef);
if (error_align > payload_align) {
if (eu_layout.error_first) {
try self.lower(Type.anyerror, error_val);
try self.lower(payload_ty, payload_val);
} else {
@ -1277,18 +1274,16 @@ pub const DeclGen = struct {
.ErrorUnion => {
const payload_ty = ty.errorUnionPayload();
const error_ty_ref = try self.resolveType(Type.anyerror, .indirect);
if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
const eu_layout = self.errorUnionLayout(payload_ty);
if (!eu_layout.payload_has_bits) {
return error_ty_ref;
}
const payload_ty_ref = try self.resolveType(payload_ty, .indirect);
const payload_align = payload_ty.abiAlignment(target);
const error_align = Type.anyerror.abiAlignment(target);
var members = std.BoundedArray(SpvType.Payload.Struct.Member, 2){};
// Similar to unions, we're going to put the most aligned member first.
if (error_align > payload_align) {
if (eu_layout.error_first) {
// Put the error first
members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" });
members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" });
@ -1336,6 +1331,34 @@ pub const DeclGen = struct {
};
}
const ErrorUnionLayout = struct {
payload_has_bits: bool,
error_first: bool,
fn errorFieldIndex(self: @This()) u32 {
assert(self.payload_has_bits);
return if (self.error_first) 0 else 1;
}
fn payloadFieldIndex(self: @This()) u32 {
assert(self.payload_has_bits);
return if (self.error_first) 1 else 0;
}
};
fn errorUnionLayout(self: *DeclGen, payload_ty: Type) ErrorUnionLayout {
const target = self.getTarget();
const error_align = Type.anyerror.abiAlignment(target);
const payload_align = payload_ty.abiAlignment(target);
const error_first = error_align > payload_align;
return .{
.payload_has_bits = payload_ty.hasRuntimeBitsIgnoreComptime(),
.error_first = error_first,
};
}
/// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
/// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
/// points. The test executor will then be able to invoke these to run the tests.
@ -1585,6 +1608,7 @@ pub const DeclGen = struct {
.loop => return self.airLoop(inst),
.ret => return self.airRet(inst),
.ret_load => return self.airRetLoad(inst),
.@"try" => try self.airTry(inst),
.switch_br => return self.airSwitchBr(inst),
.unreach => return self.airUnreach(),
@ -1752,16 +1776,15 @@ pub const DeclGen = struct {
const operand_ty_id = try self.resolveTypeId(operand_ty);
const result_type_id = try self.resolveTypeId(result_ty);
const overflow_member_ty = try self.intType(.unsigned, info.bits);
const overflow_member_ty_id = self.typeId(overflow_member_ty);
const overflow_member_ty_ref = try self.intType(.unsigned, info.bits);
const op_result_id = blk: {
// Construct the SPIR-V result type.
// It is almost the same as the zig one, except that the fields must be the same type
// and they must be unsigned.
const overflow_result_ty_ref = try self.spv.simpleStructType(&.{
.{ .ty = overflow_member_ty, .name = "res" },
.{ .ty = overflow_member_ty, .name = "ov" },
.{ .ty = overflow_member_ty_ref, .name = "res" },
.{ .ty = overflow_member_ty_ref, .name = "ov" },
});
const result_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpIAddCarry, .{
@ -1775,8 +1798,8 @@ pub const DeclGen = struct {
// Now convert the SPIR-V flavor result into a Zig-flavor result.
// First, extract the two fields.
const unsigned_result = try self.extractField(overflow_member_ty_id, op_result_id, 0);
const overflow = try self.extractField(overflow_member_ty_id, op_result_id, 1);
const unsigned_result = try self.extractField(overflow_member_ty_ref, op_result_id, 0);
const overflow = try self.extractField(overflow_member_ty_ref, op_result_id, 1);
// We need to convert the results to the types that Zig expects here.
// The `result` is the same type except unsigned, so we can just bitcast that.
@ -1954,15 +1977,16 @@ pub const DeclGen = struct {
return result_id;
}
fn extractField(self: *DeclGen, result_ty: IdResultType, object: IdRef, field: u32) !IdRef {
fn extractField(self: *DeclGen, result_ty_ref: SpvType.Ref, object: IdRef, field: u32) !IdRef {
const result_id = self.spv.allocId();
const indexes = [_]u32{field};
try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
.id_result_type = result_ty,
.id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.composite = object,
.indexes = &indexes,
});
// TODO: Convert bools, direct structs should have their field types as indirect values.
return result_id;
}
@ -1970,7 +1994,7 @@ pub const DeclGen = struct {
if (self.liveness.isUnused(inst)) return null;
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
return try self.extractField(
try self.resolveTypeId(self.air.typeOfIndex(inst)),
try self.resolveType(self.air.typeOfIndex(inst), .direct),
try self.resolve(ty_op.operand),
field,
);
@ -2451,6 +2475,66 @@ pub const DeclGen = struct {
});
}
fn airTry(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const pl_op = self.air.instructions.items(.data)[inst].pl_op;
const err_union_id = try self.resolve(pl_op.operand);
const extra = self.air.extraData(Air.Try, pl_op.payload);
const body = self.air.extra[extra.end..][0..extra.data.body_len];
const err_union_ty = self.air.typeOf(pl_op.operand);
const payload_ty = self.air.typeOfIndex(inst);
const err_ty_ref = try self.resolveType(Type.anyerror, .direct);
const payload_ty_ref = try self.resolveType(payload_ty, .direct);
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
const eu_layout = self.errorUnionLayout(payload_ty);
if (!err_union_ty.errorUnionSet().errorSetIsEmpty()) {
const err_id = if (eu_layout.payload_has_bits)
try self.extractField(err_ty_ref, err_union_id, eu_layout.errorFieldIndex())
else
err_union_id;
const zero_id = try self.constInt(err_ty_ref, 0);
const is_err_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
.id_result_type = self.typeId(bool_ty_ref),
.id_result = is_err_id,
.operand_1 = err_id,
.operand_2 = zero_id,
});
// When there is an error, we must evaluate `body`. Otherwise we must continue
// with the current body.
// Just generate a new block here, then generate a new block inline for the remainder of the body.
const err_block = self.spv.allocId();
const ok_block = self.spv.allocId();
// TODO: Merge block
try self.func.body.emit(self.spv.gpa, .OpBranchConditional, .{
.condition = is_err_id,
.true_label = err_block,
.false_label = ok_block,
});
try self.beginSpvBlock(err_block);
try self.genBody(body);
try self.beginSpvBlock(ok_block);
// Now just extract the payload, if required.
}
if (self.liveness.isUnused(inst)) {
return null;
}
if (!eu_layout.payload_has_bits) {
return null;
}
return try self.extractField(payload_ty_ref, err_union_id, eu_layout.payloadFieldIndex());
}
fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
const target = self.getTarget();
const pl_op = self.air.instructions.items(.data)[inst].pl_op;