From 509639717ac8903fe340a02cef844383183bf716 Mon Sep 17 00:00:00 2001 From: Lucas Santos Date: Fri, 20 Sep 2024 19:48:38 -0300 Subject: [PATCH] `std.equalRange`: Compute lower and upper bounds simultaneously The current implementation of `equalRange` just calls `lowerRange` and `upperRange`, but a lot of the work done by these two functions can be shared. Specifically, each iteration gives information about whether the lower bound or the upper bound can be tightened. This leads to fewer iterations and, since there is one comparison per iteration, fewer comparisons. Implementation adapted from [GCC](https://github.com/gcc-mirror/gcc/blob/519ec1cfe9d2c6a1d06709c52cb103508d2c42a7/libstdc%2B%2B-v3/include/bits/stl_algo.h#L2063). This sample demonstrates the difference between the current implementation and mine: ```zig fn S(comptime T: type) type { return struct { needle: T, count: *usize, pub fn order(context: @This(), item: T) std.math.Order { context.count.* += 1; return std.math.order(item, context.needle); } pub fn orderLength(context: @This(), item: []const u8) std.math.Order { context.count.* += 1; return std.math.order(item.len, context.needle); } }; } pub fn main() !void { var count: usize = 0; try std.testing.expectEqual(.{ 0, 0 }, equalRange(i32, &[_]i32{}, S(i32){ .needle = 0, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 0, 0 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 0, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 0, 1 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 2, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 2, 2 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 5, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 2, 3 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 8, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 5, 6 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 64, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 6, 6 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, S(i32){ .needle = 100, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 2, 6 }, equalRange(i32, &[_]i32{ 2, 4, 8, 8, 8, 8, 15, 22 }, S(i32){ .needle = 8, .count = &count }, S(i32).order)); try std.testing.expectEqual(.{ 2, 2 }, equalRange(u32, &[_]u32{ 2, 4, 8, 16, 32, 64 }, S(u32){ .needle = 5, .count = &count }, S(u32).order)); try std.testing.expectEqual(.{ 1, 1 }, equalRange(f32, &[_]f32{ -54.2, -26.7, 0.0, 56.55, 100.1, 322.0 }, S(f32){ .needle = -33.4, .count = &count }, S(f32).order)); try std.testing.expectEqual(.{ 3, 5 }, equalRange( []const u8, &[_][]const u8{ "Mars", "Venus", "Earth", "Saturn", "Uranus", "Mercury", "Jupiter", "Neptune" }, S(usize){ .needle = 6, .count = &count }, S(usize).orderLength, )); std.debug.print("Count: {}\n", .{count}); } ``` For each comparison, we bump the count. With the current implementation, we get 57 comparisons. With mine, we get 43. With contributions from @Olvilock. This is my second attempt at this, since I messed up the [first one](https://github.com/ziglang/zig/pull/21290). --- lib/std/sort.zig | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/lib/std/sort.zig b/lib/std/sort.zig index 23707f1385..8705d24017 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -769,10 +769,38 @@ pub fn equalRange( context: anytype, comptime compareFn: fn (@TypeOf(context), T) std.math.Order, ) struct { usize, usize } { - return .{ - lowerBound(T, items, context, compareFn), - upperBound(T, items, context, compareFn), - }; + var low: usize = 0; + var high: usize = items.len; + + while (low < high) { + const mid = low + (high - low) / 2; + switch (compareFn(context, items[mid])) { + .gt => { + low = mid + 1; + }, + .lt => { + high = mid; + }, + .eq => { + return .{ + low + std.sort.lowerBound( + T, + items[low..mid], + context, + compareFn, + ), + mid + std.sort.upperBound( + T, + items[mid..high], + context, + compareFn, + ), + }; + }, + } + } + + return .{ low, low }; } test equalRange { @@ -800,6 +828,7 @@ test equalRange { try std.testing.expectEqual(.{ 6, 6 }, equalRange(i32, &[_]i32{ 2, 4, 8, 16, 32, 64 }, @as(i32, 100), S.orderI32)); try std.testing.expectEqual(.{ 2, 6 }, equalRange(i32, &[_]i32{ 2, 4, 8, 8, 8, 8, 15, 22 }, @as(i32, 8), S.orderI32)); try std.testing.expectEqual(.{ 2, 2 }, equalRange(u32, &[_]u32{ 2, 4, 8, 16, 32, 64 }, @as(u32, 5), S.orderU32)); + try std.testing.expectEqual(.{ 3, 5 }, equalRange(u32, &[_]u32{ 2, 3, 4, 5, 5 }, @as(u32, 5), S.orderU32)); try std.testing.expectEqual(.{ 1, 1 }, equalRange(f32, &[_]f32{ -54.2, -26.7, 0.0, 56.55, 100.1, 322.0 }, @as(f32, -33.4), S.orderF32)); try std.testing.expectEqual(.{ 3, 5 }, equalRange( []const u8,