From 5317d88414324c6555338e8574811c6710df4e44 Mon Sep 17 00:00:00 2001 From: mlugg Date: Wed, 5 Feb 2025 21:06:39 +0000 Subject: [PATCH] Sema: fix `@errorCast` with error unions Resolves: #20169 --- src/Sema.zig | 137 ++++++++++++++++++++++++---------------- test/behavior/error.zig | 21 +++++- 2 files changed, 100 insertions(+), 58 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index d66587abb1..32f611ce3e 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -23175,11 +23175,12 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data; const src = block.nodeOffset(extra.node); const operand_src = block.builtinCallArgSrc(extra.node, 0); - const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast"); + const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast"); const operand = try sema.resolveInst(extra.rhs); - const base_operand_ty = sema.typeOf(operand); - const dest_tag = base_dest_ty.zigTypeTag(zcu); - const operand_tag = base_operand_ty.zigTypeTag(zcu); + const operand_ty = sema.typeOf(operand); + + const dest_tag = dest_ty.zigTypeTag(zcu); + const operand_tag = operand_ty.zigTypeTag(zcu); if (dest_tag != .error_set and dest_tag != .error_union) { return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)}); @@ -23191,107 +23192,133 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData return sema.fail(block, src, "cannot cast an error union type to error set", .{}); } if (dest_tag == .error_union and operand_tag == .error_union and - base_dest_ty.errorUnionPayload(zcu).toIntern() != base_operand_ty.errorUnionPayload(zcu).toIntern()) + dest_ty.errorUnionPayload(zcu).toIntern() != operand_ty.errorUnionPayload(zcu).toIntern()) { return sema.failWithOwnedErrorMsg(block, msg: { const msg = try sema.errMsg(src, "payload types of error unions must match", .{}); errdefer msg.destroy(sema.gpa); - const dest_ty = base_dest_ty.errorUnionPayload(zcu); - const operand_ty = base_operand_ty.errorUnionPayload(zcu); - try sema.errNote(src, msg, "destination payload is '{}'", .{dest_ty.fmt(pt)}); - try sema.errNote(src, msg, "operand payload is '{}'", .{operand_ty.fmt(pt)}); + const dest_payload_ty = dest_ty.errorUnionPayload(zcu); + const operand_payload_ty = operand_ty.errorUnionPayload(zcu); + try sema.errNote(src, msg, "destination payload is '{}'", .{dest_payload_ty.fmt(pt)}); + try sema.errNote(src, msg, "operand payload is '{}'", .{operand_payload_ty.fmt(pt)}); try addDeclaredHereNote(sema, msg, dest_ty); try addDeclaredHereNote(sema, msg, operand_ty); break :msg msg; }); } - const dest_ty = if (dest_tag == .error_union) base_dest_ty.errorUnionSet(zcu) else base_dest_ty; - const operand_ty = if (operand_tag == .error_union) base_operand_ty.errorUnionSet(zcu) else base_operand_ty; - - // operand must be defined since it can be an invalid error value - const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand); + const dest_err_ty = switch (dest_tag) { + .error_union => dest_ty.errorUnionSet(zcu), + .error_set => dest_ty, + else => unreachable, + }; + const operand_err_ty = switch (operand_tag) { + .error_union => operand_ty.errorUnionSet(zcu), + .error_set => operand_ty, + else => unreachable, + }; const disjoint = disjoint: { // Try avoiding resolving inferred error sets if we can - if (!dest_ty.isAnyError(zcu) and dest_ty.errorSetIsEmpty(zcu)) break :disjoint true; - if (!operand_ty.isAnyError(zcu) and operand_ty.errorSetIsEmpty(zcu)) break :disjoint true; - if (dest_ty.isAnyError(zcu)) break :disjoint false; - if (operand_ty.isAnyError(zcu)) break :disjoint false; - const dest_err_names = dest_ty.errorSetNames(zcu); + if (!dest_err_ty.isAnyError(zcu) and dest_err_ty.errorSetIsEmpty(zcu)) break :disjoint true; + if (!operand_err_ty.isAnyError(zcu) and operand_err_ty.errorSetIsEmpty(zcu)) break :disjoint true; + if (dest_err_ty.isAnyError(zcu)) break :disjoint false; + if (operand_err_ty.isAnyError(zcu)) break :disjoint false; + const dest_err_names = dest_err_ty.errorSetNames(zcu); for (0..dest_err_names.len) |dest_err_index| { - if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index])) + if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index])) break :disjoint false; } - if (!ip.isInferredErrorSetType(dest_ty.toIntern()) and - !ip.isInferredErrorSetType(operand_ty.toIntern())) + if (!ip.isInferredErrorSetType(dest_err_ty.toIntern()) and + !ip.isInferredErrorSetType(operand_err_ty.toIntern())) { break :disjoint true; } - _ = try sema.resolveInferredErrorSetTy(block, src, dest_ty.toIntern()); - _ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_ty.toIntern()); + _ = try sema.resolveInferredErrorSetTy(block, src, dest_err_ty.toIntern()); + _ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_err_ty.toIntern()); for (0..dest_err_names.len) |dest_err_index| { - if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index])) + if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index])) break :disjoint false; } break :disjoint true; }; - if (disjoint and dest_tag != .error_union) { + if (disjoint and !(operand_tag == .error_union and dest_tag == .error_union)) { return sema.fail(block, src, "error sets '{}' and '{}' have no common errors", .{ - operand_ty.fmt(pt), dest_ty.fmt(pt), + operand_err_ty.fmt(pt), dest_err_ty.fmt(pt), }); } - if (maybe_operand_val) |val| { - if (!dest_ty.isAnyError(zcu)) check: { - const operand_val = zcu.intern_pool.indexToKey(val.toIntern()); - var error_name: InternPool.NullTerminatedString = undefined; - if (operand_tag == .error_union) { - if (operand_val.error_union.val != .err_name) break :check; - error_name = operand_val.error_union.val.err_name; - } else { - error_name = operand_val.err.name; - } - if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) { - return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{ - error_name.fmt(ip), dest_ty.fmt(pt), - }); - } + // operand must be defined since it can be an invalid error value + if (try sema.resolveDefinedValue(block, operand_src, operand)) |operand_val| { + const err_name: InternPool.NullTerminatedString = switch (operand_tag) { + .error_set => ip.indexToKey(operand_val.toIntern()).err.name, + .error_union => switch (ip.indexToKey(operand_val.toIntern()).error_union.val) { + .err_name => |name| name, + .payload => |payload_val| { + assert(dest_tag == .error_union); // should be guaranteed from the type checks above + return sema.coerce(block, dest_ty, Air.internedToRef(payload_val), operand_src); + }, + }, + else => unreachable, + }; + + if (!dest_err_ty.isAnyError(zcu) and !Type.errorSetHasFieldIp(ip, dest_err_ty.toIntern(), err_name)) { + return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{ + err_name.fmt(ip), dest_err_ty.fmt(pt), + }); } - return Air.internedToRef((try pt.getCoerced(val, base_dest_ty)).toIntern()); + return Air.internedToRef(try pt.intern(switch (dest_tag) { + .error_set => .{ .err = .{ + .ty = dest_ty.toIntern(), + .name = err_name, + } }, + .error_union => .{ .error_union = .{ + .ty = dest_ty.toIntern(), + .val = .{ .err_name = err_name }, + } }, + else => unreachable, + })); } - try sema.requireRuntimeBlock(block, src, operand_src); const err_int_ty = try pt.errorIntType(); - if (block.wantSafety() and !dest_ty.isAnyError(zcu) and - dest_ty.toIntern() != .adhoc_inferred_error_set_type and + if (block.wantSafety() and !dest_err_ty.isAnyError(zcu) and + dest_err_ty.toIntern() != .adhoc_inferred_error_set_type and zcu.backendSupportsFeature(.error_set_has_value)) { - if (dest_tag == .error_union) { - const err_code = try block.addTyOp(.unwrap_errunion_err, operand_ty, operand); - const err_int = try block.addBitCast(err_int_ty, err_code); - const zero_err = try pt.intRef(try pt.errorIntType(), 0); + const err_code_inst = switch (operand_tag) { + .error_set => operand, + .error_union => try block.addTyOp(.unwrap_errunion_err, operand_err_ty, operand), + else => unreachable, + }; + const err_int_inst = try block.addBitCast(err_int_ty, err_code_inst); - const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_err); + if (dest_tag == .error_union) { + const zero_err = try pt.intRef(err_int_ty, 0); + const is_zero = try block.addBinOp(.cmp_eq, err_int_inst, zero_err); if (disjoint) { // Error must be zero. try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code); } else { // Error must be in destination set or zero. - const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code); + const has_value = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst); const ok = try block.addBinOp(.bool_or, has_value, is_zero); try sema.addSafetyCheck(block, src, ok, .invalid_error_code); } } else { - const err_int_inst = try block.addBitCast(err_int_ty, operand); - const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst); + const ok = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst); try sema.addSafetyCheck(block, src, ok, .invalid_error_code); } } - return block.addBitCast(base_dest_ty, operand); + + if (operand_tag == .error_set and dest_tag == .error_union) { + const err_val = try block.addBitCast(dest_err_ty, operand); + return block.addTyOp(.wrap_errunion_err, dest_ty, err_val); + } else { + return block.addBitCast(dest_ty, operand); + } } fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref { diff --git a/test/behavior/error.zig b/test/behavior/error.zig index c915da9a6e..5dd9cfd192 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -1060,9 +1060,24 @@ test "errorCast to adhoc inferred error set" { try std.testing.expect((try S.baz()) == 1234); } -test "errorCast from error sets to error unions" { - const err_union: Set1!void = @errorCast(error.A); - try expectError(error.A, err_union); +test "@errorCast from error set to error union" { + const S = struct { + fn doTheTest(set: error{ A, B }) error{A}!i32 { + return @errorCast(set); + } + }; + try expectError(error.A, S.doTheTest(error.A)); + try expectError(error.A, comptime S.doTheTest(error.A)); +} + +test "@errorCast from error union to error union" { + const S = struct { + fn doTheTest(set: error{ A, B }!i32) error{A}!i32 { + return @errorCast(set); + } + }; + try expectError(error.A, S.doTheTest(error.A)); + try expectError(error.A, comptime S.doTheTest(error.A)); } test "result location initialization of error union with OPV payload" {