Sema: fix @errorCast with error unions

Resolves: #20169
This commit is contained in:
mlugg 2025-02-05 21:06:39 +00:00
parent fbbf34e563
commit 5317d88414
No known key found for this signature in database
GPG Key ID: 3F5B7DCCBF4AF02E
2 changed files with 100 additions and 58 deletions

View File

@ -23175,11 +23175,12 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
const src = block.nodeOffset(extra.node);
const operand_src = block.builtinCallArgSrc(extra.node, 0);
const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
const operand = try sema.resolveInst(extra.rhs);
const base_operand_ty = sema.typeOf(operand);
const dest_tag = base_dest_ty.zigTypeTag(zcu);
const operand_tag = base_operand_ty.zigTypeTag(zcu);
const operand_ty = sema.typeOf(operand);
const dest_tag = dest_ty.zigTypeTag(zcu);
const operand_tag = operand_ty.zigTypeTag(zcu);
if (dest_tag != .error_set and dest_tag != .error_union) {
return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)});
@ -23191,107 +23192,133 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
return sema.fail(block, src, "cannot cast an error union type to error set", .{});
}
if (dest_tag == .error_union and operand_tag == .error_union and
base_dest_ty.errorUnionPayload(zcu).toIntern() != base_operand_ty.errorUnionPayload(zcu).toIntern())
dest_ty.errorUnionPayload(zcu).toIntern() != operand_ty.errorUnionPayload(zcu).toIntern())
{
return sema.failWithOwnedErrorMsg(block, msg: {
const msg = try sema.errMsg(src, "payload types of error unions must match", .{});
errdefer msg.destroy(sema.gpa);
const dest_ty = base_dest_ty.errorUnionPayload(zcu);
const operand_ty = base_operand_ty.errorUnionPayload(zcu);
try sema.errNote(src, msg, "destination payload is '{}'", .{dest_ty.fmt(pt)});
try sema.errNote(src, msg, "operand payload is '{}'", .{operand_ty.fmt(pt)});
const dest_payload_ty = dest_ty.errorUnionPayload(zcu);
const operand_payload_ty = operand_ty.errorUnionPayload(zcu);
try sema.errNote(src, msg, "destination payload is '{}'", .{dest_payload_ty.fmt(pt)});
try sema.errNote(src, msg, "operand payload is '{}'", .{operand_payload_ty.fmt(pt)});
try addDeclaredHereNote(sema, msg, dest_ty);
try addDeclaredHereNote(sema, msg, operand_ty);
break :msg msg;
});
}
const dest_ty = if (dest_tag == .error_union) base_dest_ty.errorUnionSet(zcu) else base_dest_ty;
const operand_ty = if (operand_tag == .error_union) base_operand_ty.errorUnionSet(zcu) else base_operand_ty;
// operand must be defined since it can be an invalid error value
const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
const dest_err_ty = switch (dest_tag) {
.error_union => dest_ty.errorUnionSet(zcu),
.error_set => dest_ty,
else => unreachable,
};
const operand_err_ty = switch (operand_tag) {
.error_union => operand_ty.errorUnionSet(zcu),
.error_set => operand_ty,
else => unreachable,
};
const disjoint = disjoint: {
// Try avoiding resolving inferred error sets if we can
if (!dest_ty.isAnyError(zcu) and dest_ty.errorSetIsEmpty(zcu)) break :disjoint true;
if (!operand_ty.isAnyError(zcu) and operand_ty.errorSetIsEmpty(zcu)) break :disjoint true;
if (dest_ty.isAnyError(zcu)) break :disjoint false;
if (operand_ty.isAnyError(zcu)) break :disjoint false;
const dest_err_names = dest_ty.errorSetNames(zcu);
if (!dest_err_ty.isAnyError(zcu) and dest_err_ty.errorSetIsEmpty(zcu)) break :disjoint true;
if (!operand_err_ty.isAnyError(zcu) and operand_err_ty.errorSetIsEmpty(zcu)) break :disjoint true;
if (dest_err_ty.isAnyError(zcu)) break :disjoint false;
if (operand_err_ty.isAnyError(zcu)) break :disjoint false;
const dest_err_names = dest_err_ty.errorSetNames(zcu);
for (0..dest_err_names.len) |dest_err_index| {
if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
break :disjoint false;
}
if (!ip.isInferredErrorSetType(dest_ty.toIntern()) and
!ip.isInferredErrorSetType(operand_ty.toIntern()))
if (!ip.isInferredErrorSetType(dest_err_ty.toIntern()) and
!ip.isInferredErrorSetType(operand_err_ty.toIntern()))
{
break :disjoint true;
}
_ = try sema.resolveInferredErrorSetTy(block, src, dest_ty.toIntern());
_ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_ty.toIntern());
_ = try sema.resolveInferredErrorSetTy(block, src, dest_err_ty.toIntern());
_ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_err_ty.toIntern());
for (0..dest_err_names.len) |dest_err_index| {
if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
break :disjoint false;
}
break :disjoint true;
};
if (disjoint and dest_tag != .error_union) {
if (disjoint and !(operand_tag == .error_union and dest_tag == .error_union)) {
return sema.fail(block, src, "error sets '{}' and '{}' have no common errors", .{
operand_ty.fmt(pt), dest_ty.fmt(pt),
operand_err_ty.fmt(pt), dest_err_ty.fmt(pt),
});
}
if (maybe_operand_val) |val| {
if (!dest_ty.isAnyError(zcu)) check: {
const operand_val = zcu.intern_pool.indexToKey(val.toIntern());
var error_name: InternPool.NullTerminatedString = undefined;
if (operand_tag == .error_union) {
if (operand_val.error_union.val != .err_name) break :check;
error_name = operand_val.error_union.val.err_name;
} else {
error_name = operand_val.err.name;
}
if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) {
return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{
error_name.fmt(ip), dest_ty.fmt(pt),
});
}
// operand must be defined since it can be an invalid error value
if (try sema.resolveDefinedValue(block, operand_src, operand)) |operand_val| {
const err_name: InternPool.NullTerminatedString = switch (operand_tag) {
.error_set => ip.indexToKey(operand_val.toIntern()).err.name,
.error_union => switch (ip.indexToKey(operand_val.toIntern()).error_union.val) {
.err_name => |name| name,
.payload => |payload_val| {
assert(dest_tag == .error_union); // should be guaranteed from the type checks above
return sema.coerce(block, dest_ty, Air.internedToRef(payload_val), operand_src);
},
},
else => unreachable,
};
if (!dest_err_ty.isAnyError(zcu) and !Type.errorSetHasFieldIp(ip, dest_err_ty.toIntern(), err_name)) {
return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{
err_name.fmt(ip), dest_err_ty.fmt(pt),
});
}
return Air.internedToRef((try pt.getCoerced(val, base_dest_ty)).toIntern());
return Air.internedToRef(try pt.intern(switch (dest_tag) {
.error_set => .{ .err = .{
.ty = dest_ty.toIntern(),
.name = err_name,
} },
.error_union => .{ .error_union = .{
.ty = dest_ty.toIntern(),
.val = .{ .err_name = err_name },
} },
else => unreachable,
}));
}
try sema.requireRuntimeBlock(block, src, operand_src);
const err_int_ty = try pt.errorIntType();
if (block.wantSafety() and !dest_ty.isAnyError(zcu) and
dest_ty.toIntern() != .adhoc_inferred_error_set_type and
if (block.wantSafety() and !dest_err_ty.isAnyError(zcu) and
dest_err_ty.toIntern() != .adhoc_inferred_error_set_type and
zcu.backendSupportsFeature(.error_set_has_value))
{
if (dest_tag == .error_union) {
const err_code = try block.addTyOp(.unwrap_errunion_err, operand_ty, operand);
const err_int = try block.addBitCast(err_int_ty, err_code);
const zero_err = try pt.intRef(try pt.errorIntType(), 0);
const err_code_inst = switch (operand_tag) {
.error_set => operand,
.error_union => try block.addTyOp(.unwrap_errunion_err, operand_err_ty, operand),
else => unreachable,
};
const err_int_inst = try block.addBitCast(err_int_ty, err_code_inst);
const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_err);
if (dest_tag == .error_union) {
const zero_err = try pt.intRef(err_int_ty, 0);
const is_zero = try block.addBinOp(.cmp_eq, err_int_inst, zero_err);
if (disjoint) {
// Error must be zero.
try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code);
} else {
// Error must be in destination set or zero.
const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
const has_value = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst);
const ok = try block.addBinOp(.bool_or, has_value, is_zero);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
}
} else {
const err_int_inst = try block.addBitCast(err_int_ty, operand);
const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
const ok = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
}
}
return block.addBitCast(base_dest_ty, operand);
if (operand_tag == .error_set and dest_tag == .error_union) {
const err_val = try block.addBitCast(dest_err_ty, operand);
return block.addTyOp(.wrap_errunion_err, dest_ty, err_val);
} else {
return block.addBitCast(dest_ty, operand);
}
}
fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {

View File

@ -1060,9 +1060,24 @@ test "errorCast to adhoc inferred error set" {
try std.testing.expect((try S.baz()) == 1234);
}
test "errorCast from error sets to error unions" {
const err_union: Set1!void = @errorCast(error.A);
try expectError(error.A, err_union);
test "@errorCast from error set to error union" {
const S = struct {
fn doTheTest(set: error{ A, B }) error{A}!i32 {
return @errorCast(set);
}
};
try expectError(error.A, S.doTheTest(error.A));
try expectError(error.A, comptime S.doTheTest(error.A));
}
test "@errorCast from error union to error union" {
const S = struct {
fn doTheTest(set: error{ A, B }!i32) error{A}!i32 {
return @errorCast(set);
}
};
try expectError(error.A, S.doTheTest(error.A));
try expectError(error.A, comptime S.doTheTest(error.A));
}
test "result location initialization of error union with OPV payload" {