diff --git a/lib/std/rb.zig b/lib/std/rb.zig index c41a269a27..d7840b75b0 100644 --- a/lib/std/rb.zig +++ b/lib/std/rb.zig @@ -1,4 +1,4 @@ -const std = @import("std.zig"); +const std = @import("std"); const assert = std.debug.assert; const testing = std.testing; const Order = std.math.Order; @@ -11,6 +11,7 @@ const Red = Color.Red; const Black = Color.Black; const ReplaceError = error{NotEqual}; +const SortError = error{NotUnique}; // The new comparison function results in duplicates. /// Insert this into your struct that you want to add to a red-black tree. /// Do not use a pointer. Turn the *rb.Node results of the functions in rb @@ -132,7 +133,21 @@ pub const Node = struct { pub const Tree = struct { root: ?*Node, - compareFn: fn (*Node, *Node) Order, + compareFn: fn (*Node, *Node, *Tree) Order, + + /// Re-sorts a tree with a new compare function + pub fn sort(tree: *Tree, newCompareFn: fn (*Node, *Node, *Tree) Order) SortError!void { + var newTree = Tree.init(newCompareFn); + var node: *Node = undefined; + while (true) { + node = tree.first() orelse break; + tree.remove(node); + if (newTree.insert(node) != null) { + return error.NotUnique; // EEXISTS + } + } + tree.* = newTree; + } /// If you have a need for a version that caches this, please file a bug. pub fn first(tree: *Tree) ?*Node { @@ -244,6 +259,7 @@ pub const Tree = struct { return doLookup(key, tree, &parent, &is_left); } + /// If node is not part of tree, behavior is undefined. pub fn remove(tree: *Tree, nodeconst: *Node) void { var node = nodeconst; // as this has the same value as node, it is unsafe to access node after newnode @@ -389,7 +405,7 @@ pub const Tree = struct { var new = newconst; // I assume this can get optimized out if the caller already knows. - if (tree.compareFn(old, new) != .eq) return ReplaceError.NotEqual; + if (tree.compareFn(old, new, tree) != .eq) return ReplaceError.NotEqual; if (old.getParent()) |parent| { parent.setChild(new, parent.left == old); @@ -404,9 +420,11 @@ pub const Tree = struct { new.* = old.*; } - pub fn init(tree: *Tree, f: fn (*Node, *Node) Order) void { - tree.root = null; - tree.compareFn = f; + pub fn init(f: fn (*Node, *Node, *Tree) Order) Tree { + return Tree{ + .root = null, + .compareFn = f, + }; } }; @@ -469,7 +487,7 @@ fn doLookup(key: *Node, tree: *Tree, pparent: *?*Node, is_left: *bool) ?*Node { is_left.* = false; while (maybe_node) |node| { - const res = tree.compareFn(node, key); + const res = tree.compareFn(node, key, tree); if (res == .eq) { return node; } @@ -498,7 +516,7 @@ fn testGetNumber(node: *Node) *testNumber { return @fieldParentPtr(testNumber, "node", node); } -fn testCompare(l: *Node, r: *Node) Order { +fn testCompare(l: *Node, r: *Node, contextIgnored: *Tree) Order { var left = testGetNumber(l); var right = testGetNumber(r); @@ -512,13 +530,17 @@ fn testCompare(l: *Node, r: *Node) Order { unreachable; } +fn testCompareReverse(l: *Node, r: *Node, contextIgnored: *Tree) Order { + return testCompare(r, l, contextIgnored); +} + test "rb" { if (@import("builtin").arch == .aarch64) { // TODO https://github.com/ziglang/zig/issues/3288 return error.SkipZigTest; } - var tree: Tree = undefined; + var tree = Tree.init(testCompare); var ns: [10]testNumber = undefined; ns[0].value = 42; ns[1].value = 41; @@ -534,7 +556,6 @@ test "rb" { var dup: testNumber = undefined; dup.value = 32345; - tree.init(testCompare); _ = tree.insert(&ns[1].node); _ = tree.insert(&ns[2].node); _ = tree.insert(&ns[3].node); @@ -557,8 +578,7 @@ test "rb" { } test "inserting and looking up" { - var tree: Tree = undefined; - tree.init(testCompare); + var tree = Tree.init(testCompare); var number: testNumber = undefined; number.value = 1000; _ = tree.insert(&number.node); @@ -582,8 +602,7 @@ test "multiple inserts, followed by calling first and last" { // TODO https://github.com/ziglang/zig/issues/3288 return error.SkipZigTest; } - var tree: Tree = undefined; - tree.init(testCompare); + var tree = Tree.init(testCompare); var zeroth: testNumber = undefined; zeroth.value = 0; var first: testNumber = undefined; @@ -601,4 +620,8 @@ test "multiple inserts, followed by calling first and last" { var lookupNode: testNumber = undefined; lookupNode.value = 3; assert(tree.lookup(&lookupNode.node) == &third.node); + tree.sort(testCompareReverse) catch unreachable; + assert(testGetNumber(tree.first().?).value == 3); + assert(testGetNumber(tree.last().?).value == 0); + assert(tree.lookup(&lookupNode.node) == &third.node); }