diff --git a/src/Module.zig b/src/Module.zig index 93e4b87d5b..693cc3b5a0 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -824,7 +824,7 @@ pub const ErrorSet = struct { /// Offset from Decl node index, points to the error set AST node. node_offset: i32, /// The string bytes are stored in the owner Decl arena. - /// They are in the same order they appear in the AST. + /// These must be in sorted order. See sortNames. names: NameMap, pub const NameMap = std.StringArrayHashMapUnmanaged(void); @@ -836,6 +836,18 @@ pub const ErrorSet = struct { .lazy = .{ .node_offset = self.node_offset }, }; } + + /// sort the NameMap. This should be called whenever the map is modified. + /// alloc should be the allocator used for the NameMap data. + pub fn sortNames(names: *NameMap) void { + const Context = struct { + keys: [][]const u8, + pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool { + return std.mem.lessThan(u8, ctx.keys[a_index], ctx.keys[b_index]); + } + }; + names.sort(Context{ .keys = names.keys() }); + } }; pub const RequiresComptime = enum { no, yes, unknown, wip }; diff --git a/src/Sema.zig b/src/Sema.zig index 195a0ef274..f74fa1e0bf 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2212,6 +2212,10 @@ fn zirErrorSetDecl( return sema.fail(block, src, "duplicate error set field {s}", .{name}); } } + + // names must be sorted. + Module.ErrorSet.sortNames(&names); + error_set.* = .{ .owner_decl = new_decl, .node_offset = inst_data.src_node, diff --git a/src/type.zig b/src/type.zig index b0ff9e59e9..e35006d930 100644 --- a/src/type.zig +++ b/src/type.zig @@ -556,35 +556,37 @@ pub const Type = extern union { return info_a.signedness == info_b.signedness and info_a.bits == info_b.bits; }, + .error_set_inferred => { + // Inferred error sets are only equal if both are inferred + // and they originate from the exact same function. + const a_set = a.castTag(.error_set_inferred).?.data; + const b_set = (b.castTag(.error_set_inferred) orelse return false).data; + return a_set.func == b_set.func; + }, + + .anyerror => { + return b.tag() == .anyerror; + }, + .error_set, .error_set_single, - .anyerror, - .error_set_inferred, .error_set_merged, => { - if (b.zigTypeTag() != .ErrorSet) return false; - - // TODO: revisit the language specification for how to evaluate equality - // for error set types. - - if (a.tag() == .anyerror and b.tag() == .anyerror) { - return true; + switch (b.tag()) { + .error_set, .error_set_single, .error_set_merged => {}, + else => return false, } - if (a.tag() == .error_set and b.tag() == .error_set) { - return a.castTag(.error_set).?.data.owner_decl == b.castTag(.error_set).?.data.owner_decl; + // Two resolved sets match if their error set names match. + // Since they are pre-sorted we compare them element-wise. + const a_set = a.errorSetNames(); + const b_set = b.errorSetNames(); + if (a_set.len != b_set.len) return false; + for (a_set) |a_item, i| { + const b_item = b_set[i]; + if (!std.mem.eql(u8, a_item, b_item)) return false; } - - if (a.tag() == .error_set_inferred and b.tag() == .error_set_inferred) { - return a.castTag(.error_set_inferred).?.data == b.castTag(.error_set_inferred).?.data; - } - - if (a.tag() == .error_set_single and b.tag() == .error_set_single) { - const a_data = a.castTag(.error_set_single).?.data; - const b_data = b.castTag(.error_set_single).?.data; - return std.mem.eql(u8, a_data, b_data); - } - return false; + return true; }, .@"opaque" => { @@ -961,12 +963,30 @@ pub const Type = extern union { .error_set, .error_set_single, - .anyerror, - .error_set_inferred, .error_set_merged, => { + // all are treated like an "error set" for hashing std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); - // TODO implement this after revisiting Type.Eql for error sets + std.hash.autoHash(hasher, Tag.error_set); + + const names = ty.errorSetNames(); + std.hash.autoHash(hasher, names.len); + assert(std.sort.isSorted([]const u8, names, u8, std.mem.lessThan)); + for (names) |name| hasher.update(name); + }, + + .anyerror => { + // anyerror is distinct from other error sets + std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); + std.hash.autoHash(hasher, Tag.anyerror); + }, + + .error_set_inferred => { + // inferred error sets are compared using their data pointer + const set = ty.castTag(.error_set_inferred).?.data; + std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); + std.hash.autoHash(hasher, Tag.error_set_inferred); + std.hash.autoHash(hasher, set.func); }, .@"opaque" => { @@ -4365,6 +4385,9 @@ pub const Type = extern union { try names.put(arena, name, {}); } + // names must be sorted + Module.ErrorSet.sortNames(&names); + return try Tag.error_set_merged.create(arena, names); } diff --git a/src/value.zig b/src/value.zig index 121e380bd9..502de64348 100644 --- a/src/value.zig +++ b/src/value.zig @@ -1870,6 +1870,16 @@ pub const Value = extern union { return eql(a_payload.container_ptr, b_payload.container_ptr, ty); }, + .@"error" => { + const a_name = a.castTag(.@"error").?.data.name; + const b_name = b.castTag(.@"error").?.data.name; + return std.mem.eql(u8, a_name, b_name); + }, + .eu_payload => { + const a_payload = a.castTag(.eu_payload).?.data; + const b_payload = b.castTag(.eu_payload).?.data; + return eql(a_payload, b_payload, ty.errorUnionPayload()); + }, .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .array => { diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig index 6f300773ae..1f92a0214f 100644 --- a/test/behavior/cast.zig +++ b/test/behavior/cast.zig @@ -669,8 +669,8 @@ test "peer type resolution: disjoint error sets" { try expect(error_set_info == .ErrorSet); try expect(error_set_info.ErrorSet.?.len == 3); try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two")); } { @@ -678,8 +678,8 @@ test "peer type resolution: disjoint error sets" { const error_set_info = @typeInfo(ty); try expect(error_set_info == .ErrorSet); try expect(error_set_info.ErrorSet.?.len == 3); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three")); try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two")); } } @@ -704,8 +704,8 @@ test "peer type resolution: error union and error set" { const error_set_info = @typeInfo(info.ErrorUnion.error_set); try expect(error_set_info.ErrorSet.?.len == 3); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three")); try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two")); } @@ -717,8 +717,8 @@ test "peer type resolution: error union and error set" { const error_set_info = @typeInfo(info.ErrorUnion.error_set); try expect(error_set_info.ErrorSet.?.len == 3); try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two")); } } diff --git a/test/behavior/error.zig b/test/behavior/error.zig index 7dd0d44e01..73e03b1c3e 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -330,7 +330,11 @@ fn intLiteral(str: []const u8) !?i64 { } test "nested error union function call in optional unwrap" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO const S = struct { const Foo = struct { @@ -375,7 +379,11 @@ test "nested error union function call in optional unwrap" { } test "return function call to error set from error union function" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO const S = struct { fn errorable() anyerror!i32 { @@ -404,7 +412,11 @@ test "optional error set is the same size as error set" { } test "nested catch" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO const S = struct { fn entry() !void { @@ -428,11 +440,20 @@ test "nested catch" { } test "function pointer with return type that is error union with payload which is pointer of parent struct" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage1) { + // stage1 has wrong function pointer semantics + return error.SkipZigTest; + } + + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO const S = struct { const Foo = struct { - fun: fn (a: i32) (anyerror!*Foo), + fun: *const fn (a: i32) (anyerror!*Foo), }; const Err = error{UnspecifiedErr}; @@ -480,7 +501,11 @@ test "return result loc as peer result loc in inferred error set function" { } test "error payload type is correctly resolved" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO const MyIntWrapper = struct { const Self = @This(); @@ -496,8 +521,6 @@ test "error payload type is correctly resolved" { } test "error union comptime caching" { - if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO - const S = struct { fn quux(comptime arg: anytype) void { arg catch {}; @@ -539,3 +562,59 @@ test "@errorName sentinel length matches slice length" { pub fn testBuiltinErrorName(err: anyerror) [:0]const u8 { return @errorName(err); } + +test "error set equality" { + // This tests using stage2 logic (#11022) + if (builtin.zig_backend == .stage1) return error.SkipZigTest; + + const a = error{One}; + const b = error{One}; + + try expect(a == a); + try expect(a == b); + try expect(a == error{One}); + + // should treat as a set + const c = error{ One, Two }; + const d = error{ Two, One }; + + try expect(c == d); +} + +test "inferred error set equality" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + const S = struct { + fn foo() !void { + return @This().bar(); + } + + fn bar() !void { + return error.Bad; + } + + fn baz() !void { + return quux(); + } + + fn quux() anyerror!void {} + }; + + const FooError = @typeInfo(@typeInfo(@TypeOf(S.foo)).Fn.return_type.?).ErrorUnion.error_set; + const BarError = @typeInfo(@typeInfo(@TypeOf(S.bar)).Fn.return_type.?).ErrorUnion.error_set; + const BazError = @typeInfo(@typeInfo(@TypeOf(S.baz)).Fn.return_type.?).ErrorUnion.error_set; + + try expect(BarError != error{Bad}); + + try expect(FooError != anyerror); + try expect(BarError != anyerror); + try expect(BazError != anyerror); + + try expect(FooError != BarError); + try expect(FooError != BazError); + try expect(BarError != BazError); + + try expect(FooError == FooError); + try expect(BarError == BarError); + try expect(BazError == BazError); +} diff --git a/test/behavior/type_info.zig b/test/behavior/type_info.zig index ad8fe03c15..b2ea6ecbe0 100644 --- a/test/behavior/type_info.zig +++ b/test/behavior/type_info.zig @@ -205,11 +205,12 @@ test "type info: error set single value" { } test "type info: error set merged" { + // #11022 forces ordering of error sets in stage2 + if (builtin.zig_backend == .stage1) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO const TestSet = error{ One, Two } || error{Three}; @@ -217,8 +218,8 @@ test "type info: error set merged" { try expect(error_set_info == .ErrorSet); try expect(error_set_info.ErrorSet.?.len == 3); try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two")); - try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three")); + try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two")); } test "type info: enum info" { diff --git a/test/stage2/x86_64.zig b/test/stage2/x86_64.zig index aeed29bf43..458289e06e 100644 --- a/test/stage2/x86_64.zig +++ b/test/stage2/x86_64.zig @@ -1412,25 +1412,6 @@ pub fn addCases(ctx: *TestContext) !void { }); } - { - var case = ctx.exe("error set equality", target); - - case.addCompareOutput( - \\pub fn main() void { - \\ assert(@TypeOf(error.Foo) == @TypeOf(error.Foo)); - \\ assert(@TypeOf(error.Bar) != @TypeOf(error.Foo)); - \\ assert(anyerror == anyerror); - \\ assert(error{Foo} != error{Foo}); - \\ // TODO put inferred error sets here when @typeInfo works - \\} - \\fn assert(b: bool) void { - \\ if (!b) unreachable; - \\} - , - "", - ); - } - { var case = ctx.exe("comptime var", target);