From 869880adac4f708b0fd1e65de6046b561c38528c Mon Sep 17 00:00:00 2001 From: Dominic <4678790+dweiller@users.noreply.github.com> Date: Sat, 11 May 2024 19:06:13 +1000 Subject: [PATCH] astgen: fix result info for catch switch_block_err_union --- lib/std/zig/AstGen.zig | 6 +- test/behavior.zig | 1 + test/behavior/switch_on_captured_error.zig | 82 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index a52007eabf..6328fa3e86 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -7071,8 +7071,10 @@ fn switchExprErrUnion( .ctx = ri.ctx, }; - const payload_is_ref = node_ty == .@"if" and - if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk; + const payload_is_ref = switch (node_ty) { + .@"if" => if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk, + .@"catch" => ri.rl == .ref or ri.rl == .ref_coerced_ty, + }; // We need to call `rvalue` to write through to the pointer only if we had a // result pointer and aren't forwarding it. diff --git a/test/behavior.zig b/test/behavior.zig index 3081f6c9f9..ea8ea713ac 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -89,6 +89,7 @@ test { _ = @import("behavior/switch.zig"); _ = @import("behavior/switch_prong_err_enum.zig"); _ = @import("behavior/switch_prong_implicit_cast.zig"); + _ = @import("behavior/switch_on_captured_error.zig"); _ = @import("behavior/this.zig"); _ = @import("behavior/threadlocal.zig"); _ = @import("behavior/truncate.zig"); diff --git a/test/behavior/switch_on_captured_error.zig b/test/behavior/switch_on_captured_error.zig index f5ba762559..6e70c851b1 100644 --- a/test/behavior/switch_on_captured_error.zig +++ b/test/behavior/switch_on_captured_error.zig @@ -3,9 +3,11 @@ const assert = std.debug.assert; const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; +const builtin = @import("builtin"); test "switch on error union catch capture" { if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; const S = struct { const Error = error{ A, B, C }; @@ -16,6 +18,7 @@ test "switch on error union catch capture" { try testCapture(); try testInline(); try testEmptyErrSet(); + try testAddressOf(); } fn testScalar() !void { @@ -252,6 +255,44 @@ test "switch on error union catch capture" { try expectEqual(@as(u64, 0), b); } } + + fn testAddressOf() !void { + { + const a: anyerror!usize = 0; + const ptr = &(a catch |e| switch (e) { + else => 3, + }); + comptime assert(@TypeOf(ptr) == *const usize); + try expectEqual(ptr, &(a catch unreachable)); + } + { + const a: anyerror!usize = error.A; + const ptr = &(a catch |e| switch (e) { + else => 3, + }); + comptime assert(@TypeOf(ptr) == *const comptime_int); + try expectEqual(3, ptr.*); + } + { + var a: anyerror!usize = 0; + _ = &a; + const ptr = &(a catch |e| switch (e) { + else => return, + }); + comptime assert(@TypeOf(ptr) == *usize); + ptr.* += 1; + try expectEqual(@as(usize, 1), a catch unreachable); + } + { + var a: anyerror!usize = error.A; + _ = &a; + const ptr = &(a catch |e| switch (e) { + else => return, + }); + comptime assert(@TypeOf(ptr) == *usize); + unreachable; + } + } }; try comptime S.doTheTest(); @@ -260,6 +301,7 @@ test "switch on error union catch capture" { test "switch on error union if else capture" { if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; const S = struct { const Error = error{ A, B, C }; @@ -276,6 +318,7 @@ test "switch on error union if else capture" { try testInlinePtr(); try testEmptyErrSet(); try testEmptyErrSetPtr(); + try testAddressOf(); } fn testScalar() !void { @@ -747,6 +790,45 @@ test "switch on error union if else capture" { try expectEqual(@as(u64, 0), b); } } + + fn testAddressOf() !void { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + { + const a: anyerror!usize = 0; + const ptr = &(if (a) |*v| v.* else |e| switch (e) { + else => 3, + }); + comptime assert(@TypeOf(ptr) == *const usize); + try expectEqual(ptr, &(a catch unreachable)); + } + { + const a: anyerror!usize = error.A; + const ptr = &(if (a) |*v| v.* else |e| switch (e) { + else => 3, + }); + comptime assert(@TypeOf(ptr) == *const comptime_int); + try expectEqual(3, ptr.*); + } + { + var a: anyerror!usize = 0; + _ = &a; + const ptr = &(if (a) |*v| v.* else |e| switch (e) { + else => return, + }); + comptime assert(@TypeOf(ptr) == *usize); + ptr.* += 1; + try expectEqual(@as(usize, 1), a catch unreachable); + } + { + var a: anyerror!usize = error.A; + _ = &a; + const ptr = &(if (a) |*v| v.* else |e| switch (e) { + else => return, + }); + comptime assert(@TypeOf(ptr) == *usize); + unreachable; + } + } }; try comptime S.doTheTest();