Merge pull request #16840 from e4m2/rand-int

std.rand: Support integers with >64 bits in more functions
This commit is contained in:
Andrew Kelley 2023-10-21 05:26:19 -04:00 committed by GitHub
commit 3d6e633371
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,22 +113,16 @@ pub const Random = struct {
// TODO: endian portability is pointless if the underlying prng isn't endian portable. // TODO: endian portability is pointless if the underlying prng isn't endian portable.
// TODO: document the endian portability of this library. // TODO: document the endian portability of this library.
const byte_aligned_result = mem.readIntSliceLittle(ByteAlignedT, &rand_bytes); const byte_aligned_result = mem.readIntSliceLittle(ByteAlignedT, &rand_bytes);
const unsigned_result = @as(UnsignedT, @truncate(byte_aligned_result)); const unsigned_result: UnsignedT = @truncate(byte_aligned_result);
return @as(T, @bitCast(unsigned_result)); return @bitCast(unsigned_result);
} }
/// Constant-time implementation off `uintLessThan`. /// Constant-time implementation off `uintLessThan`.
/// The results of this function may be biased. /// The results of this function may be biased.
pub fn uintLessThanBiased(r: Random, comptime T: type, less_than: T) T { pub fn uintLessThanBiased(r: Random, comptime T: type, less_than: T) T {
comptime assert(@typeInfo(T).Int.signedness == .unsigned); comptime assert(@typeInfo(T).Int.signedness == .unsigned);
const bits = @typeInfo(T).Int.bits;
comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
assert(0 < less_than); assert(0 < less_than);
if (bits <= 32) { return limitRangeBiased(T, r.int(T), less_than);
return @as(T, @intCast(limitRangeBiased(u32, r.int(u32), less_than)));
} else {
return @as(T, @intCast(limitRangeBiased(u64, r.int(u64), less_than)));
}
} }
/// Returns an evenly distributed random unsigned integer `0 <= i < less_than`. /// Returns an evenly distributed random unsigned integer `0 <= i < less_than`.
@ -142,22 +136,16 @@ pub const Random = struct {
pub fn uintLessThan(r: Random, comptime T: type, less_than: T) T { pub fn uintLessThan(r: Random, comptime T: type, less_than: T) T {
comptime assert(@typeInfo(T).Int.signedness == .unsigned); comptime assert(@typeInfo(T).Int.signedness == .unsigned);
const bits = @typeInfo(T).Int.bits; const bits = @typeInfo(T).Int.bits;
comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
assert(0 < less_than); assert(0 < less_than);
// Small is typically u32
const small_bits = @divTrunc(bits + 31, 32) * 32;
const Small = std.meta.Int(.unsigned, small_bits);
// Large is typically u64
const Large = std.meta.Int(.unsigned, small_bits * 2);
// adapted from: // adapted from:
// http://www.pcg-random.org/posts/bounded-rands.html // http://www.pcg-random.org/posts/bounded-rands.html
// "Lemire's (with an extra tweak from me)" // "Lemire's (with an extra tweak from me)"
var x: Small = r.int(Small); var x = r.int(T);
var m: Large = @as(Large, x) * @as(Large, less_than); var m = math.mulWide(T, x, less_than);
var l: Small = @as(Small, @truncate(m)); var l: T = @truncate(m);
if (l < less_than) { if (l < less_than) {
var t: Small = -%less_than; var t = -%less_than;
if (t >= less_than) { if (t >= less_than) {
t -= less_than; t -= less_than;
@ -166,12 +154,12 @@ pub const Random = struct {
} }
} }
while (l < t) { while (l < t) {
x = r.int(Small); x = r.int(T);
m = @as(Large, x) * @as(Large, less_than); m = math.mulWide(T, x, less_than);
l = @as(Small, @truncate(m)); l = @truncate(m);
} }
} }
return @as(T, @intCast(m >> small_bits)); return @intCast(m >> bits);
} }
/// Constant-time implementation off `uintAtMost`. /// Constant-time implementation off `uintAtMost`.
@ -205,10 +193,10 @@ pub const Random = struct {
if (info.signedness == .signed) { if (info.signedness == .signed) {
// Two's complement makes this math pretty easy. // Two's complement makes this math pretty easy.
const UnsignedT = std.meta.Int(.unsigned, info.bits); const UnsignedT = std.meta.Int(.unsigned, info.bits);
const lo = @as(UnsignedT, @bitCast(at_least)); const lo: UnsignedT = @bitCast(at_least);
const hi = @as(UnsignedT, @bitCast(less_than)); const hi: UnsignedT = @bitCast(less_than);
const result = lo +% r.uintLessThanBiased(UnsignedT, hi -% lo); const result = lo +% r.uintLessThanBiased(UnsignedT, hi -% lo);
return @as(T, @bitCast(result)); return @bitCast(result);
} else { } else {
// The signed implementation would work fine, but we can use stricter arithmetic operators here. // The signed implementation would work fine, but we can use stricter arithmetic operators here.
return at_least + r.uintLessThanBiased(T, less_than - at_least); return at_least + r.uintLessThanBiased(T, less_than - at_least);
@ -224,10 +212,10 @@ pub const Random = struct {
if (info.signedness == .signed) { if (info.signedness == .signed) {
// Two's complement makes this math pretty easy. // Two's complement makes this math pretty easy.
const UnsignedT = std.meta.Int(.unsigned, info.bits); const UnsignedT = std.meta.Int(.unsigned, info.bits);
const lo = @as(UnsignedT, @bitCast(at_least)); const lo: UnsignedT = @bitCast(at_least);
const hi = @as(UnsignedT, @bitCast(less_than)); const hi: UnsignedT = @bitCast(less_than);
const result = lo +% r.uintLessThan(UnsignedT, hi -% lo); const result = lo +% r.uintLessThan(UnsignedT, hi -% lo);
return @as(T, @bitCast(result)); return @bitCast(result);
} else { } else {
// The signed implementation would work fine, but we can use stricter arithmetic operators here. // The signed implementation would work fine, but we can use stricter arithmetic operators here.
return at_least + r.uintLessThan(T, less_than - at_least); return at_least + r.uintLessThan(T, less_than - at_least);
@ -242,10 +230,10 @@ pub const Random = struct {
if (info.signedness == .signed) { if (info.signedness == .signed) {
// Two's complement makes this math pretty easy. // Two's complement makes this math pretty easy.
const UnsignedT = std.meta.Int(.unsigned, info.bits); const UnsignedT = std.meta.Int(.unsigned, info.bits);
const lo = @as(UnsignedT, @bitCast(at_least)); const lo: UnsignedT = @bitCast(at_least);
const hi = @as(UnsignedT, @bitCast(at_most)); const hi: UnsignedT = @bitCast(at_most);
const result = lo +% r.uintAtMostBiased(UnsignedT, hi -% lo); const result = lo +% r.uintAtMostBiased(UnsignedT, hi -% lo);
return @as(T, @bitCast(result)); return @bitCast(result);
} else { } else {
// The signed implementation would work fine, but we can use stricter arithmetic operators here. // The signed implementation would work fine, but we can use stricter arithmetic operators here.
return at_least + r.uintAtMostBiased(T, at_most - at_least); return at_least + r.uintAtMostBiased(T, at_most - at_least);
@ -261,10 +249,10 @@ pub const Random = struct {
if (info.signedness == .signed) { if (info.signedness == .signed) {
// Two's complement makes this math pretty easy. // Two's complement makes this math pretty easy.
const UnsignedT = std.meta.Int(.unsigned, info.bits); const UnsignedT = std.meta.Int(.unsigned, info.bits);
const lo = @as(UnsignedT, @bitCast(at_least)); const lo: UnsignedT = @bitCast(at_least);
const hi = @as(UnsignedT, @bitCast(at_most)); const hi: UnsignedT = @bitCast(at_most);
const result = lo +% r.uintAtMost(UnsignedT, hi -% lo); const result = lo +% r.uintAtMost(UnsignedT, hi -% lo);
return @as(T, @bitCast(result)); return @bitCast(result);
} else { } else {
// The signed implementation would work fine, but we can use stricter arithmetic operators here. // The signed implementation would work fine, but we can use stricter arithmetic operators here.
return at_least + r.uintAtMost(T, at_most - at_least); return at_least + r.uintAtMost(T, at_most - at_least);
@ -293,9 +281,9 @@ pub const Random = struct {
rand_lz += @clz(r.int(u32) | 0x7FF); rand_lz += @clz(r.int(u32) | 0x7FF);
} }
} }
const mantissa = @as(u23, @truncate(rand)); const mantissa: u23 = @truncate(rand);
const exponent = @as(u32, 126 - rand_lz) << 23; const exponent = @as(u32, 126 - rand_lz) << 23;
return @as(f32, @bitCast(exponent | mantissa)); return @bitCast(exponent | mantissa);
}, },
f64 => { f64 => {
// Use 52 random bits for the mantissa, and the rest for the exponent. // Use 52 random bits for the mantissa, and the rest for the exponent.
@ -320,7 +308,7 @@ pub const Random = struct {
} }
const mantissa = rand & 0xFFFFFFFFFFFFF; const mantissa = rand & 0xFFFFFFFFFFFFF;
const exponent = (1022 - rand_lz) << 52; const exponent = (1022 - rand_lz) << 52;
return @as(f64, @bitCast(exponent | mantissa)); return @bitCast(exponent | mantissa);
}, },
else => @compileError("unknown floating point type"), else => @compileError("unknown floating point type"),
} }
@ -332,7 +320,7 @@ pub const Random = struct {
pub fn floatNorm(r: Random, comptime T: type) T { pub fn floatNorm(r: Random, comptime T: type) T {
const value = ziggurat.next_f64(r, ziggurat.NormDist); const value = ziggurat.next_f64(r, ziggurat.NormDist);
switch (T) { switch (T) {
f32 => return @as(f32, @floatCast(value)), f32 => return @floatCast(value),
f64 => return value, f64 => return value,
else => @compileError("unknown floating point type"), else => @compileError("unknown floating point type"),
} }
@ -344,7 +332,7 @@ pub const Random = struct {
pub fn floatExp(r: Random, comptime T: type) T { pub fn floatExp(r: Random, comptime T: type) T {
const value = ziggurat.next_f64(r, ziggurat.ExpDist); const value = ziggurat.next_f64(r, ziggurat.ExpDist);
switch (T) { switch (T) {
f32 => return @as(f32, @floatCast(value)), f32 => return @floatCast(value),
f64 => return value, f64 => return value,
else => @compileError("unknown floating point type"), else => @compileError("unknown floating point type"),
} }
@ -378,10 +366,10 @@ pub const Random = struct {
} }
// `i <= j < max <= maxInt(MinInt)` // `i <= j < max <= maxInt(MinInt)`
const max = @as(MinInt, @intCast(buf.len)); const max: MinInt = @intCast(buf.len);
var i: MinInt = 0; var i: MinInt = 0;
while (i < max - 1) : (i += 1) { while (i < max - 1) : (i += 1) {
const j = @as(MinInt, @intCast(r.intRangeLessThan(Index, i, max))); const j: MinInt = @intCast(r.intRangeLessThan(Index, i, max));
mem.swap(T, &buf[i], &buf[j]); mem.swap(T, &buf[i], &buf[j]);
} }
} }
@ -438,13 +426,12 @@ pub const Random = struct {
pub fn limitRangeBiased(comptime T: type, random_int: T, less_than: T) T { pub fn limitRangeBiased(comptime T: type, random_int: T, less_than: T) T {
comptime assert(@typeInfo(T).Int.signedness == .unsigned); comptime assert(@typeInfo(T).Int.signedness == .unsigned);
const bits = @typeInfo(T).Int.bits; const bits = @typeInfo(T).Int.bits;
const T2 = std.meta.Int(.unsigned, bits * 2);
// adapted from: // adapted from:
// http://www.pcg-random.org/posts/bounded-rands.html // http://www.pcg-random.org/posts/bounded-rands.html
// "Integer Multiplication (Biased)" // "Integer Multiplication (Biased)"
var m: T2 = @as(T2, random_int) * @as(T2, less_than); const m = math.mulWide(T, random_int, less_than);
return @as(T, @intCast(m >> bits)); return @intCast(m >> bits);
} }
// Generator to extend 64-bit seed values into longer sequences. // Generator to extend 64-bit seed values into longer sequences.