diff --git a/lib/std/enums.zig b/lib/std/enums.zig index 32f04df012..440b505283 100644 --- a/lib/std/enums.zig +++ b/lib/std/enums.zig @@ -304,6 +304,346 @@ pub fn EnumMap(comptime E: type, comptime V: type) type { return IndexedMap(EnumIndexer(E), V, mixin.EnumMapExt); } +/// A multiset of enum elements up to a count of usize. Backed +/// by an EnumArray. This type does no dynamic allocation and can +/// be copied by value. +pub fn EnumMultiset(comptime E: type) type { + return BoundedEnumMultiset(E, usize); +} + +/// A multiset of enum elements up to CountSize. Backed by an +/// EnumArray. This type does no dynamic allocation and can be +/// copied by value. +pub fn BoundedEnumMultiset(comptime E: type, comptime CountSize: type) type { + return struct { + const Self = @This(); + + counts: EnumArray(E, CountSize), + + /// Initializes the multiset using a struct of counts. + pub fn init(init_counts: EnumFieldStruct(E, CountSize, 0)) Self { + var self = initWithCount(0); + inline for (@typeInfo(E).Enum.fields) |field| { + const c = @field(init_counts, field.name); + const key = @intToEnum(E, field.value); + self.counts.set(key, c); + } + return self; + } + + /// Initializes the multiset with a count of zero. + pub fn initEmpty() Self { + return initWithCount(0); + } + + /// Initializes the multiset with all keys at the + /// same count. + pub fn initWithCount(comptime c: CountSize) Self { + return .{ + .counts = EnumArray(E, CountSize).initDefault(c, .{}), + }; + } + + /// Returns the total number of key counts in the multiset. + pub fn count(self: Self) usize { + var sum: usize = 0; + for (self.counts.values) |c| { + sum += c; + } + return sum; + } + + /// Checks if at least one key in multiset. + pub fn contains(self: Self, key: E) bool { + return self.counts.get(key) > 0; + } + + /// Removes all instance of a key from multiset. Same as + /// setCount(key, 0). + pub fn removeAll(self: *Self, key: E) void { + return self.counts.set(key, 0); + } + + /// Increases the key count by given amount. Caller asserts + /// operation will not overflow. + pub fn addAssertSafe(self: *Self, key: E, c: CountSize) void { + self.counts.getPtr(key).* += c; + } + + /// Increases the key count by given amount. + pub fn add(self: *Self, key: E, c: CountSize) error{Overflow}!void { + self.counts.set(key, try std.math.add(CountSize, self.counts.get(key), c)); + } + + /// Decreases the key count by given amount. If amount is + /// greater than the number of keys in multset, then key count + /// will be set to zero. + pub fn remove(self: *Self, key: E, c: CountSize) void { + self.counts.getPtr(key).* -= @min(self.getCount(key), c); + } + + /// Returns the count for a key. + pub fn getCount(self: Self, key: E) CountSize { + return self.counts.get(key); + } + + /// Set the count for a key. + pub fn setCount(self: *Self, key: E, c: CountSize) void { + self.counts.set(key, c); + } + + /// Increases the all key counts by given multiset. Caller + /// asserts operation will not overflow any key. + pub fn addSetAssertSafe(self: *Self, other: Self) void { + inline for (@typeInfo(E).Enum.fields) |field| { + const key = @intToEnum(E, field.value); + self.addAssertSafe(key, other.getCount(key)); + } + } + + /// Increases the all key counts by given multiset. + pub fn addSet(self: *Self, other: Self) error{Overflow}!void { + inline for (@typeInfo(E).Enum.fields) |field| { + const key = @intToEnum(E, field.value); + try self.add(key, other.getCount(key)); + } + } + + /// Deccreases the all key counts by given multiset. If + /// the given multiset has more key counts than this, + /// then that key will have a key count of zero. + pub fn removeSet(self: *Self, other: Self) void { + inline for (@typeInfo(E).Enum.fields) |field| { + const key = @intToEnum(E, field.value); + self.remove(key, other.getCount(key)); + } + } + + /// Returns true iff all key counts are the same as + /// given multiset. + pub fn eql(self: Self, other: Self) bool { + inline for (@typeInfo(E).Enum.fields) |field| { + const key = @intToEnum(E, field.value); + if (self.getCount(key) != other.getCount(key)) { + return false; + } + } + return true; + } + + /// Returns a multiset with the total key count of this + /// multiset and the other multiset. Caller asserts + /// operation will not overflow any key. + pub fn plusAssertSafe(self: Self, other: Self) Self { + var result = self; + result.addSetAssertSafe(other); + return result; + } + + /// Returns a multiset with the total key count of this + /// multiset and the other multiset. + pub fn plus(self: Self, other: Self) error{Overflow}!Self { + var result = self; + try result.addSet(other); + return result; + } + + /// Returns a multiset with the key count of this + /// multiset minus the corresponding key count in the + /// other multiset. If the other multiset contains + /// more key count than this set, that key will have + /// a count of zero. + pub fn minus(self: Self, other: Self) Self { + var result = self; + result.removeSet(other); + return result; + } + + pub const Entry = EnumArray(E, CountSize).Entry; + pub const Iterator = EnumArray(E, CountSize).Iterator; + + /// Returns an iterator over this multiset. Keys with zero + /// counts are included. Modifications to the set during + /// iteration may or may not be observed by the iterator, + /// but will not invalidate it. + pub fn iterator(self: *Self) Iterator { + return self.counts.iterator(); + } + }; +} + +test "EnumMultiset" { + const Ball = enum { red, green, blue }; + + const empty = EnumMultiset(Ball).initEmpty(); + const r0_g1_b2 = EnumMultiset(Ball).init(.{ + .red = 0, + .green = 1, + .blue = 2, + }); + const ten_of_each = EnumMultiset(Ball).initWithCount(10); + + try testing.expectEqual(empty.count(), 0); + try testing.expectEqual(r0_g1_b2.count(), 3); + try testing.expectEqual(ten_of_each.count(), 30); + + try testing.expect(!empty.contains(.red)); + try testing.expect(!empty.contains(.green)); + try testing.expect(!empty.contains(.blue)); + + try testing.expect(!r0_g1_b2.contains(.red)); + try testing.expect(r0_g1_b2.contains(.green)); + try testing.expect(r0_g1_b2.contains(.blue)); + + try testing.expect(ten_of_each.contains(.red)); + try testing.expect(ten_of_each.contains(.green)); + try testing.expect(ten_of_each.contains(.blue)); + + { + var copy = ten_of_each; + copy.removeAll(.red); + try testing.expect(!copy.contains(.red)); + + // removeAll second time does nothing + copy.removeAll(.red); + try testing.expect(!copy.contains(.red)); + } + + { + var copy = ten_of_each; + copy.addAssertSafe(.red, 6); + try testing.expectEqual(copy.getCount(.red), 16); + } + + { + var copy = ten_of_each; + try copy.add(.red, 6); + try testing.expectEqual(copy.getCount(.red), 16); + + try testing.expectError(error.Overflow, copy.add(.red, std.math.maxInt(usize))); + } + + { + var copy = ten_of_each; + copy.remove(.red, 4); + try testing.expectEqual(copy.getCount(.red), 6); + + // subtracting more it contains does not underflow + copy.remove(.green, 14); + try testing.expectEqual(copy.getCount(.green), 0); + } + + try testing.expectEqual(empty.getCount(.green), 0); + try testing.expectEqual(r0_g1_b2.getCount(.green), 1); + try testing.expectEqual(ten_of_each.getCount(.green), 10); + + { + var copy = empty; + copy.setCount(.red, 6); + try testing.expectEqual(copy.getCount(.red), 6); + } + + { + var copy = r0_g1_b2; + copy.addSetAssertSafe(ten_of_each); + try testing.expectEqual(copy.getCount(.red), 10); + try testing.expectEqual(copy.getCount(.green), 11); + try testing.expectEqual(copy.getCount(.blue), 12); + } + + { + var copy = r0_g1_b2; + try copy.addSet(ten_of_each); + try testing.expectEqual(copy.getCount(.red), 10); + try testing.expectEqual(copy.getCount(.green), 11); + try testing.expectEqual(copy.getCount(.blue), 12); + + const full = EnumMultiset(Ball).initWithCount(std.math.maxInt(usize)); + try testing.expectError(error.Overflow, copy.addSet(full)); + } + + { + var copy = ten_of_each; + copy.removeSet(r0_g1_b2); + try testing.expectEqual(copy.getCount(.red), 10); + try testing.expectEqual(copy.getCount(.green), 9); + try testing.expectEqual(copy.getCount(.blue), 8); + + copy.removeSet(ten_of_each); + try testing.expectEqual(copy.getCount(.red), 0); + try testing.expectEqual(copy.getCount(.green), 0); + try testing.expectEqual(copy.getCount(.blue), 0); + } + + try testing.expect(empty.eql(empty)); + try testing.expect(r0_g1_b2.eql(r0_g1_b2)); + try testing.expect(ten_of_each.eql(ten_of_each)); + try testing.expect(!empty.eql(r0_g1_b2)); + try testing.expect(!r0_g1_b2.eql(ten_of_each)); + try testing.expect(!ten_of_each.eql(empty)); + + { + const result = r0_g1_b2.plusAssertSafe(ten_of_each); + try testing.expectEqual(result.getCount(.red), 10); + try testing.expectEqual(result.getCount(.green), 11); + try testing.expectEqual(result.getCount(.blue), 12); + } + + { + const result = try r0_g1_b2.plus(ten_of_each); + try testing.expectEqual(result.getCount(.red), 10); + try testing.expectEqual(result.getCount(.green), 11); + try testing.expectEqual(result.getCount(.blue), 12); + + const full = EnumMultiset(Ball).initWithCount(std.math.maxInt(usize)); + try testing.expectError(error.Overflow, result.plus(full)); + } + + { + const result = ten_of_each.minus(r0_g1_b2); + try testing.expectEqual(result.getCount(.red), 10); + try testing.expectEqual(result.getCount(.green), 9); + try testing.expectEqual(result.getCount(.blue), 8); + } + + { + const result = ten_of_each.minus(r0_g1_b2).minus(ten_of_each); + try testing.expectEqual(result.getCount(.red), 0); + try testing.expectEqual(result.getCount(.green), 0); + try testing.expectEqual(result.getCount(.blue), 0); + } + + { + var copy = empty; + var it = copy.iterator(); + var entry = it.next().?; + try testing.expectEqual(entry.key, .red); + try testing.expectEqual(entry.value.*, 0); + entry = it.next().?; + try testing.expectEqual(entry.key, .green); + try testing.expectEqual(entry.value.*, 0); + entry = it.next().?; + try testing.expectEqual(entry.key, .blue); + try testing.expectEqual(entry.value.*, 0); + try testing.expectEqual(it.next(), null); + } + + { + var copy = r0_g1_b2; + var it = copy.iterator(); + var entry = it.next().?; + try testing.expectEqual(entry.key, .red); + try testing.expectEqual(entry.value.*, 0); + entry = it.next().?; + try testing.expectEqual(entry.key, .green); + try testing.expectEqual(entry.value.*, 1); + entry = it.next().?; + try testing.expectEqual(entry.key, .blue); + try testing.expectEqual(entry.value.*, 2); + try testing.expectEqual(it.next(), null); + } +} + /// An array keyed by an enum, backed by a dense array. /// If the enum is not dense, a mapping will be constructed from /// enum values to dense indices. This type does no dynamic