[heapsort] Protect against integer overflow

(Firstly, I changed `n` to `b`, as that is less confusing. It's not a length, it's a right boundary.)

The invariant maintained is `cur < b`. In the worst case, `2*cur + 1` can therefore be as large as `2b - 1`. Since `2b - 1` is not guaranteed to fit in a `usize` (i.e. to be at most `maxInt`), we have to add one overflow check to `siftDown` to make sure we avoid undefined behavior.

LLVM also seems to have a nicer time compiling this version of the function. It is about 2x faster in my tests (I think LLVM was stumped by the `child += @intFromBool` line), and adding or removing the overflow check makes a negligible performance difference on my machine. Of course, we could check `2b <= maxInt` in the parent function and dispatch to a version of the function without the overflow check in the common case, but that is probably not worth the extra code size just to eliminate a single instruction.
This commit is contained in:
Niles Salter 2023-06-22 11:32:28 -06:00 committed by GitHub
parent c60896743d
commit 7d511d6428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,6 +36,8 @@ pub fn insertion(
/// O(1) memory (no allocator required). /// O(1) memory (no allocator required).
/// Sorts in ascending order with respect to the given `lessThan` function. /// Sorts in ascending order with respect to the given `lessThan` function.
pub fn insertionContext(a: usize, b: usize, context: anytype) void { pub fn insertionContext(a: usize, b: usize, context: anytype) void {
assert(a <= b);
var i = a + 1; var i = a + 1;
while (i < b) : (i += 1) { while (i < b) : (i += 1) {
var j = i; var j = i;
@ -73,6 +75,7 @@ pub fn heap(
/// O(1) memory (no allocator required). /// O(1) memory (no allocator required).
/// Sorts in ascending order with respect to the given `lessThan` function. /// Sorts in ascending order with respect to the given `lessThan` function.
pub fn heapContext(a: usize, b: usize, context: anytype) void { pub fn heapContext(a: usize, b: usize, context: anytype) void {
assert(a <= b);
// build the heap in linear time. // build the heap in linear time.
var i = a + (b - a) / 2; var i = a + (b - a) / 2;
while (i > a) { while (i > a) {
@ -89,22 +92,33 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void {
} }
} }
/// Sifts the element at index `target` down the max-heap laid out over the
/// index range `[a, b)` until it is not less than either of its children.
///
/// `a` is the index of the heap root and `b` is the exclusive right boundary.
/// The children of a node at index `cur` are at `a + 2 * (cur - a) + 1` and the
/// index directly after it. `context` must provide `lessThan(i, j)` and
/// `swap(i, j)`, both taking indices.
///
/// Assumes `a <= target` (the subtraction `cur - a` relies on it) — callers
/// are expected to guarantee this.
fn siftDown(a: usize, target: usize, b: usize, context: anytype) void {
    var cur = target;
    while (true) {
        // The left child is at `a + 2 * (cur - a) + 1`. Both the multiply and the
        // `+ a` can exceed `maxInt(usize)` (e.g. `a = 1`, `cur - a = maxInt / 2`),
        // so both are checked. An overflow means the child index is larger than
        // any representable `b`, i.e. `cur` has no children and we are done.
        const doubled = math.mul(usize, cur - a, 2) catch break;
        // `doubled` is even and `maxInt(usize)` is odd, so `doubled + 1` cannot overflow.
        var child = math.add(usize, doubled + 1, a) catch break;

        // stop if we overshot the boundary
        if (child >= b) break;

        // `child < b`, so `next_child` is at most `b` and no overflow is possible
        const next_child = child + 1;

        // store the greater child in `child`
        if (next_child < b and context.lessThan(child, next_child)) {
            child = next_child;
        }

        // stop if the heap invariant holds at `cur`, i.e. `cur` is not less than
        // its greater child. Equal elements are deliberately left in place so
        // that no redundant swaps are performed.
        if (!context.lessThan(cur, child)) break;

        // swap `cur` with the greater child,
        // move one step down, and continue sifting.
        context.swap(child, cur);
        cur = child;
    }
}