diff --git a/std/sort.zig b/std/sort.zig index 15302e5fb0..af65cd3a59 100644 --- a/std/sort.zig +++ b/std/sort.zig @@ -4,47 +4,13 @@ const math = @import("math.zig"); pub const Cmp = math.Cmp; -pub fn sort(inline T: type, array: []T) { +pub fn sort(inline T: type, array: []T, inline cmp: fn(a: T, b: T)->Cmp) { if (array.len > 0) { - quicksort(T, array, 0, array.len - 1); + quicksort(T, array, 0, array.len - 1, cmp); } } -fn quicksort(inline T: type, array: []T, left: usize, right: usize) { - var i = left; - var j = right; - var p = (i + j) / 2; - - while (i <= j) { - while (array[i] < array[p]) { - i += 1; - } - while (array[j] > array[p]) { - j -= 1; - } - if (i <= j) { - const tmp = array[i]; - array[i] = array[j]; - array[j] = tmp; - i += 1; - if (j > 0) j -= 1; - } - } - - if (left < j) quicksort(T, array, left, j); - if (i < right) quicksort(T, array, i, right); -} - -// --------------------------------------- -// sortCmp - -pub fn sortCmp(inline T: type, array: []T, inline cmp: fn(a: T, b: T)->Cmp) { - if (array.len > 0) { - quicksortCmp(T, array, 0, array.len - 1, cmp); - } -} - -fn quicksortCmp(inline T: type, array: []T, left: usize, right: usize, inline cmp: fn(a: T, b: T)->Cmp) { +fn quicksort(inline T: type, array: []T, left: usize, right: usize, inline cmp: fn(a: T, b: T)->Cmp) { var i = left; var j = right; var p = (i + j) / 2; @@ -65,8 +31,28 @@ fn quicksortCmp(inline T: type, array: []T, left: usize, right: usize, inline cm } } - if (left < j) quicksortCmp(T, array, left, j, cmp); - if (i < right) quicksortCmp(T, array, i, right, cmp); + if (left < j) quicksort(T, array, left, j, cmp); + if (i < right) quicksort(T, array, i, right, cmp); +} + +pub fn i32asc(a: i32, b: i32) -> Cmp { + return if (a > b) Cmp.Greater else if (a < b) Cmp.Less else Cmp.Equal; +} + +pub fn i32desc(a: i32, b: i32) -> Cmp { + return reverse(i32asc(a, b)); +} + +pub fn u8asc(a: u8, b: u8) -> Cmp { + return if (a > b) Cmp.Greater else if (a < b) Cmp.Less else Cmp.Equal; +} + +pub fn u8desc(a: u8, b: u8) -> Cmp { + return reverse(u8asc(a, b)); +} + +fn reverse(was: Cmp) -> Cmp { + return if (was == Cmp.Greater) Cmp.Less else if (was == Cmp.Less) Cmp.Greater else Cmp.Equal; } // --------------------------------------- @@ -85,7 +71,7 @@ fn testSort() { }; for (u8cases) |case| { - sort(u8, case[0]); + sort(u8, case[0], u8asc); assert(str.eql(case[0], case[1])); } @@ -99,28 +85,14 @@ fn testSort() { }; for (i32cases) |case| { - sort(i32, case[0]); + sort(i32, case[0], i32asc); assert(str.sliceEql(i32, case[0], case[1])); } } -fn testSortCmp() { +fn testSortDesc() { @setFnTest(this, true); - const i32cases = [][][]i32 { - [][]i32{[]i32{}, []i32{}}, - [][]i32{[]i32{1}, []i32{1}}, - [][]i32{[]i32{0, 1}, []i32{0, 1}}, - [][]i32{[]i32{1, 0}, []i32{0, 1}}, - [][]i32{[]i32{1, -1, 0}, []i32{-1, 0, 1}}, - [][]i32{[]i32{2, 1, 3}, []i32{1, 2, 3}}, - }; - - for (i32cases) |case| { - sortCmp(i32, case[0], normalCmp); - assert(str.sliceEql(i32, case[0], case[1])); - } - const revCases = [][][]i32 { [][]i32{[]i32{}, []i32{}}, [][]i32{[]i32{1}, []i32{1}}, @@ -131,16 +103,9 @@ fn testSortCmp() { }; for (revCases) |case| { - sortCmp(i32, case[0], revCmp); + sort(i32, case[0], i32desc); assert(str.sliceEql(i32, case[0], case[1])); } } -fn normalCmp(a: i32, b: i32) -> Cmp { - return if (a > b) Cmp.Greater else if (a < b) Cmp.Less else Cmp.Equal; -} - -fn revCmp(a: i32, b: i32) -> Cmp { - return if (a < b) Cmp.Greater else if (a > b) Cmp.Less else Cmp.Equal; -}