Sema: fix issues in @errorCast with error unions

This commit is contained in:
Veikka Tuominen 2023-10-02 15:44:50 +03:00 committed by Andrew Kelley
parent c9c3ee704c
commit 0bdbd3e235
3 changed files with 46 additions and 9 deletions

View File

@ -21771,10 +21771,10 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
// operand must be defined since it can be an invalid error value
const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
if (disjoint: {
const disjoint = disjoint: {
// Try avoiding resolving inferred error sets if we can
if (!dest_ty.isAnyError(mod) and dest_ty.errorSetNames(mod).len == 0) break :disjoint true;
if (!operand_ty.isAnyError(mod) and operand_ty.errorSetNames(mod).len == 0) break :disjoint true;
if (!dest_ty.isAnyError(mod) and dest_ty.errorSetIsEmpty(mod)) break :disjoint true;
if (!operand_ty.isAnyError(mod) and operand_ty.errorSetIsEmpty(mod)) break :disjoint true;
if (dest_ty.isAnyError(mod)) break :disjoint false;
if (operand_ty.isAnyError(mod)) break :disjoint false;
for (dest_ty.errorSetNames(mod)) |dest_err_name| {
@ -21796,7 +21796,8 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
}
break :disjoint true;
}) {
};
if (disjoint and dest_tag != .ErrorUnion) {
const msg = msg: {
const msg = try sema.errMsg(
block,
@ -21850,10 +21851,16 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
.int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } },
}));
const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16);
const ok = try block.addBinOp(.bit_or, has_value, is_zero);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
if (disjoint) {
// Error must be zero.
try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code);
} else {
// Error must be in destination set or zero.
const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
const ok = try block.addBinOp(.bit_or, has_value, is_zero);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
}
} else {
const err_int_inst = try block.addBitCast(Type.err_int, operand);
const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);

View File

@ -238,13 +238,23 @@ fn testExplicitErrorSetCast(set1: Set1) !void {
test "@errorCast on error unions" {
const S = struct {
fn doTheTest() !void {
const casted: error{Bad}!i32 = @errorCast(retErrUnion());
try expect((try casted) == 1234);
{
const casted: error{Bad}!i32 = @errorCast(retErrUnion());
try expect((try casted) == 1234);
}
{
const casted: error{Bad}!i32 = @errorCast(retInferredErrUnion());
try expect((try casted) == 5678);
}
}
fn retErrUnion() anyerror!i32 {
return 1234;
}
fn retInferredErrUnion() !i32 {
return 5678;
}
};
try S.doTheTest();

View File

@ -0,0 +1,20 @@
const std = @import("std");
pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn {
_ = stack_trace;
if (std.mem.eql(u8, message, "invalid error code")) {
std.process.exit(0);
}
std.process.exit(1);
}
pub fn main() !void {
const bar: error{Foo}!i32 = @errorCast(foo());
_ = &bar;
return error.TestFailed;
}
fn foo() anyerror!i32 {
return error.Bar;
}
// run
// backend=llvm
// target=native