From 5bb8c03697fce798a966feb131f6d906863047ae Mon Sep 17 00:00:00 2001 From: Justin Whear Date: Sun, 28 Aug 2022 04:19:51 -0700 Subject: [PATCH] std.random: add weightedIndex function `weightedIndex` picks from a selection of weighted indices. --- lib/std/rand.zig | 36 ++++++++++++++++++++++++++++++++++++ lib/std/rand/test.zig | 26 ++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/lib/std/rand.zig b/lib/std/rand.zig index c562230cca..c13d2895c9 100644 --- a/lib/std/rand.zig +++ b/lib/std/rand.zig @@ -337,6 +337,42 @@ pub const Random = struct { mem.swap(T, &buf[i], &buf[j]); } } + + /// Randomly selects an index into `proportions`, where the likelihood of each + /// index is weighted by that proportion. + /// + /// This is useful for selecting an item from a slice where weights are not equal. + /// `T` must be a numeric type capable of holding the sum of `proportions`. + pub fn weightedIndex(r: std.rand.Random, comptime T: type, proportions: []T) usize { + // This implementation works by summing the proportions and picking a random + // point in [0, sum). We then loop over the proportions, accumulating + // until our accumulator is greater than the random point. + + var sum: T = 0; + for (proportions) |v| { + sum += v; + } + + const point = if (comptime std.meta.trait.isSignedInt(T)) + r.intRangeLessThan(T, 0, sum) + else if (comptime std.meta.trait.isUnsignedInt(T)) + r.uintLessThan(T, sum) + else if (comptime std.meta.trait.isFloat(T)) + // take care that imprecision doesn't lead to a value slightly greater than sum + std.math.min(r.float(T) * sum, sum - std.math.epsilon(T)) + else + @compileError("weightedIndex does not support proportions of type " ++ @typeName(T)); + + std.debug.assert(point < sum); + + var accumulator: T = 0; + for (proportions) |p, index| { + accumulator += p; + if (point < accumulator) return index; + } + + unreachable; + } }; /// Convert a random integer 0 <= random_int <= maxValue(T), diff --git a/lib/std/rand/test.zig b/lib/std/rand/test.zig index 7c2016901f..cae77d6e37 100644 --- a/lib/std/rand/test.zig +++ b/lib/std/rand/test.zig @@ -445,3 +445,29 @@ test "CSPRNG" { const c = random.int(u64); try expect(a ^ b ^ c != 0); } + +test "Random weightedIndex" { + // Make sure weightedIndex works for various integers and floats + inline for (.{ u64, i4, f32, f64 }) |T| { + var prng = DefaultPrng.init(0); + const random = prng.random(); + + var proportions = [_]T{ 2, 1, 1, 2 }; + var counts = [_]f64{ 0, 0, 0, 0 }; + + const n_trials: u64 = 10_000; + var i: usize = 0; + while (i < n_trials) : (i += 1) { + const pick = random.weightedIndex(T, &proportions); + counts[pick] += 1; + } + + // We expect the first and last counts to be roughly 2x the second and third + const approxEqRel = std.math.approxEqRel; + // Define "roughly" to be within 10% + const tolerance = 0.1; + try std.testing.expect(approxEqRel(f64, counts[0], counts[1] * 2, tolerance)); + try std.testing.expect(approxEqRel(f64, counts[1], counts[2], tolerance)); + try std.testing.expect(approxEqRel(f64, counts[2] * 2, counts[3], tolerance)); + } +}