diff --git a/src/Sema.zig b/src/Sema.zig index ca8e91dd68..02f6712d6c 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -7622,6 +7622,10 @@ fn zirErrUnionCode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro const inst_data = sema.code.instructions.items(.data)[inst].un_node; const src = inst_data.src(); const operand = try sema.resolveInst(inst_data.operand); + return sema.analyzeErrUnionCode(block, src, operand); +} + +fn analyzeErrUnionCode(sema: *Sema, block: *Block, src: LazySrcLoc, operand: Air.Inst.Ref) CompileError!Air.Inst.Ref { const operand_ty = sema.typeOf(operand); if (operand_ty.zigTypeTag() != .ErrorUnion) { return sema.fail(block, src, "expected error union type, found '{}'", .{ @@ -14129,6 +14133,14 @@ fn analyzeCmp( // numeric types. return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src); } + if (is_equality_cmp and lhs_ty.zigTypeTag() == .ErrorUnion and rhs_ty.zigTypeTag() == .ErrorSet) { + const casted_lhs = try sema.analyzeErrUnionCode(block, lhs_src, lhs); + return sema.cmpSelf(block, src, casted_lhs, rhs, op, lhs_src, rhs_src); + } + if (is_equality_cmp and lhs_ty.zigTypeTag() == .ErrorSet and rhs_ty.zigTypeTag() == .ErrorUnion) { + const casted_rhs = try sema.analyzeErrUnionCode(block, rhs_src, rhs); + return sema.cmpSelf(block, src, lhs, casted_rhs, op, lhs_src, rhs_src); + } const instructions = &[_]Air.Inst.Ref{ lhs, rhs }; const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{ .override = &[_]LazySrcLoc{ lhs_src, rhs_src } }); if (!resolved_type.isSelfComparable(is_equality_cmp)) { diff --git a/test/behavior/error.zig b/test/behavior/error.zig index dc29c9bc5b..ba5bdfdc20 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -809,3 +809,24 @@ test "alignment of wrapping an error union payload" { }; try expect((S.foo() catch unreachable).x == 1234); } + +test "compare error union and error set" { + if (builtin.zig_backend == .stage1) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + + var a: anyerror = error.Foo; + var b: anyerror!u32 = error.Bar; + + try expect(a != b); + try expect(b != a); + + b = error.Foo; + + try expect(a == b); + try expect(b == a); + + b = 2; + + try expect(a != b); + try expect(b != a); +} diff --git a/test/cases/compile_errors/comparison_with_error_union_and_error_value.zig b/test/cases/compile_errors/comparison_with_error_union_and_error_value.zig deleted file mode 100644 index 3a9eabcb95..0000000000 --- a/test/cases/compile_errors/comparison_with_error_union_and_error_value.zig +++ /dev/null @@ -1,10 +0,0 @@ -export fn entry() void { - var number_or_error: anyerror!i32 = error.SomethingAwful; - _ = number_or_error == error.SomethingAwful; -} - -// error -// backend=stage2 -// target=native -// -// :3:25: error: operator == not allowed for type 'anyerror!i32'