diff --git a/lib/std/sort.zig b/lib/std/sort.zig index 928503ad40..1ae7e1e574 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -36,6 +36,8 @@ pub fn insertion( /// O(1) memory (no allocator required). /// Sorts in ascending order with respect to the given `lessThan` function. pub fn insertionContext(a: usize, b: usize, context: anytype) void { + assert(a <= b); + var i = a + 1; while (i < b) : (i += 1) { var j = i; @@ -73,6 +75,7 @@ pub fn heap( /// O(1) memory (no allocator required). /// Sorts in ascending order with respect to the given `lessThan` function. pub fn heapContext(a: usize, b: usize, context: anytype) void { + assert(a <= b); // build the heap in linear time. var i = a + (b - a) / 2; while (i > a) { @@ -89,22 +92,33 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void { } } -fn siftDown(a: usize, root: usize, n: usize, context: anytype) void { - var node = root; +fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { + var cur = target; while (true) { - var child = a + 2 * (node - a) + 1; - if (child >= n) break; + // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 + // The `+ a + 1` is safe because: + // for `a > 0` then `2a >= a + 1`. + // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. + var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; - // choose the greater child. - child += @intFromBool(child + 1 < n and context.lessThan(child, child + 1)); + // stop if we overshot the boundary + if (!(child < b)) break; - // stop if the invariant holds at `node`. - if (!context.lessThan(node, child)) break; + // `next_child` is at most `b`, therefore no overflow is possible + const next_child = child + 1; - // swap `node` with the greater child, + // store the greater child in `child` + if (next_child < b and context.lessThan(child, next_child)) { + child = next_child; + } + + // stop if the Heap invariant holds at `cur`. + if (context.lessThan(child, cur)) break; + + // swap `cur` with the greater child, // move one step down, and continue sifting. - context.swap(node, child); - node = child; + context.swap(child, cur); + cur = child; } }