diff --git a/std/hash_map.zig b/std/hash_map.zig
index 6ea128c9ad..9cd1ea052c 100644
--- a/std/hash_map.zig
+++ b/std/hash_map.zig
@@ -118,7 +118,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
                 };
             }
             self.incrementModificationCount();
-            try self.ensureCapacity();
+            try self.autoCapacity();
             const put_result = self.internalPut(key);
             assert(put_result.old_kv == null);
             return GetOrPutResult{
@@ -135,15 +135,37 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
             return res.kv;
         }
 
-        fn ensureCapacity(self: *Self) !void {
-            if (self.entries.len == 0) {
-                return self.initCapacity(16);
+        fn optimizedCapacity(expected_count: usize) usize {
+            // ensure that the hash map will be at most 60% full if
+            // expected_count items are put into it
+            var optimized_capacity = expected_count * 5 / 3;
+            // round capacity to the next power of two
+            const pow = math.log2_int_ceil(usize, optimized_capacity);
+            return math.pow(usize, 2, pow);
+        }
+
+        /// Increases capacity so that the hash map will be at most
+        /// 60% full when expected_count items are put into it
+        pub fn ensureCapacity(self: *Self, expected_count: usize) !void {
+            const optimized_capacity = optimizedCapacity(expected_count);
+            return self.ensureCapacityExact(optimized_capacity);
+        }
+
+        /// Sets the capacity to the new capacity if the new
+        /// capacity is greater than the current capacity.
+        /// New capacity must be a power of two.
+        fn ensureCapacityExact(self: *Self, new_capacity: usize) !void {
+            const is_power_of_two = new_capacity & (new_capacity - 1) == 0;
+            assert(is_power_of_two);
+
+            if (new_capacity <= self.entries.len) {
+                return;
             }
 
-            // if we get too full (60%), double the capacity
-            if (self.size * 5 >= self.entries.len * 3) {
-                const old_entries = self.entries;
-                try self.initCapacity(self.entries.len * 2);
+            const old_entries = self.entries;
+            try self.initCapacity(new_capacity);
+            self.incrementModificationCount();
+            if (old_entries.len > 0) {
                 // dump all of the old elements into the new table
                 for (old_entries) |*old_entry| {
                     if (old_entry.used) {
@@ -156,8 +178,13 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
 
         /// Returns the kv pair that was already there.
         pub fn put(self: *Self, key: K, value: V) !?KV {
+            try self.autoCapacity();
+            return putAssumeCapacity(self, key, value);
+        }
+
+        pub fn putAssumeCapacity(self: *Self, key: K, value: V) ?KV {
+            assert(self.count() < self.entries.len);
             self.incrementModificationCount();
-            try self.ensureCapacity();
 
             const put_result = self.internalPut(key);
             put_result.new_entry.kv.value = value;
@@ -227,6 +254,16 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
             return other;
         }
 
+        fn autoCapacity(self: *Self) !void {
+            if (self.entries.len == 0) {
+                return self.ensureCapacityExact(16);
+            }
+            // if we get too full (60%), double the capacity
+            if (self.size * 5 >= self.entries.len * 3) {
+                return self.ensureCapacityExact(self.entries.len * 2);
+            }
+        }
+
         fn initCapacity(hm: *Self, capacity: usize) !void {
             hm.entries = try hm.allocator.alloc(Entry, capacity);
             hm.size = 0;
@@ -427,6 +464,24 @@ test "iterator hash map" {
     testing.expect(entry.value == values[0]);
 }
 
+test "ensure capacity" {
+    var direct_allocator = std.heap.DirectAllocator.init();
+    defer direct_allocator.deinit();
+
+    var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
+    defer map.deinit();
+
+    try map.ensureCapacity(20);
+    const initial_capacity = map.entries.len;
+    testing.expect(initial_capacity >= 20);
+    var i: i32 = 0;
+    while (i < 20) : (i += 1) {
+        testing.expect(map.putAssumeCapacity(i, i + 10) == null);
+    }
+    // shouldn't resize from putAssumeCapacity
+    testing.expect(initial_capacity == map.entries.len);
+}
+
 pub fn getHashPtrAddrFn(comptime K: type) (fn (K) u32) {
     return struct {
         fn hash(key: K) u32 {
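
Usage sketch (not part of the patch): a minimal example of the intended calling
pattern, reusing the DirectAllocator setup from the test above. ensureCapacity(100)
computes 100 * 5 / 3 == 166 and rounds up to the next power of two, 256, so the
100 inserts leave the table about 39% full and putAssumeCapacity never needs to
grow it:

    var direct_allocator = std.heap.DirectAllocator.init();
    defer direct_allocator.deinit();

    var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
    defer map.deinit();

    // the only fallible step: reserve space for 100 entries up front
    try map.ensureCapacity(100);

    var i: i32 = 0;
    while (i < 100) : (i += 1) {
        // infallible; returns the previous KV for the key, or null
        _ = map.putAssumeCapacity(i, i * 3);
    }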