From b729a3f008304d464b431f8ac34ad16cde08ba7b Mon Sep 17 00:00:00 2001 From: expikr <77922942+expikr@users.noreply.github.com> Date: Sat, 20 Jan 2024 14:32:07 +0800 Subject: [PATCH] std.math: make hypot infer type from argument (#17910) using peer type resolution --- lib/std/math/complex/abs.zig | 5 ++--- lib/std/math/complex/sqrt.zig | 8 ++++---- lib/std/math/hypot.zig | 22 +++++++++++++++------- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/lib/std/math/complex/abs.zig b/lib/std/math/complex/abs.zig index 8716360535..97999ee775 100644 --- a/lib/std/math/complex/abs.zig +++ b/lib/std/math/complex/abs.zig @@ -5,9 +5,8 @@ const cmath = math.complex; const Complex = cmath.Complex; /// Returns the absolute value (modulus) of z. -pub fn abs(z: anytype) @TypeOf(z.re) { - const T = @TypeOf(z.re); - return math.hypot(T, z.re, z.im); +pub fn abs(z: anytype) @TypeOf(z.re, z.im) { + return math.hypot(z.re, z.im); } const epsilon = 0.0001; diff --git a/lib/std/math/complex/sqrt.zig b/lib/std/math/complex/sqrt.zig index fe2880e355..745940b955 100644 --- a/lib/std/math/complex/sqrt.zig +++ b/lib/std/math/complex/sqrt.zig @@ -56,13 +56,13 @@ fn sqrt32(z: Complex(f32)) Complex(f32) { const dy = @as(f64, y); if (dx >= 0) { - const t = @sqrt((dx + math.hypot(f64, dx, dy)) * 0.5); + const t = @sqrt((dx + math.hypot(dx, dy)) * 0.5); return Complex(f32).init( @as(f32, @floatCast(t)), @as(f32, @floatCast(dy / (2.0 * t))), ); } else { - const t = @sqrt((-dx + math.hypot(f64, dx, dy)) * 0.5); + const t = @sqrt((-dx + math.hypot(dx, dy)) * 0.5); return Complex(f32).init( @as(f32, @floatCast(@abs(y) / (2.0 * t))), @as(f32, @floatCast(math.copysign(t, y))), @@ -112,10 +112,10 @@ fn sqrt64(z: Complex(f64)) Complex(f64) { var result: Complex(f64) = undefined; if (x >= 0) { - const t = @sqrt((x + math.hypot(f64, x, y)) * 0.5); + const t = @sqrt((x + math.hypot(x, y)) * 0.5); result = Complex(f64).init(t, y / (2.0 * t)); } else { - const t = @sqrt((-x + math.hypot(f64, x, y)) * 0.5); + const t = @sqrt((-x + math.hypot(x, y)) * 0.5); result = Complex(f64).init(@abs(y) / (2.0 * t), math.copysign(t, y)); } diff --git a/lib/std/math/hypot.zig b/lib/std/math/hypot.zig index 9fb569667b..fc5b0ddca8 100644 --- a/lib/std/math/hypot.zig +++ b/lib/std/math/hypot.zig @@ -12,11 +12,15 @@ const maxInt = std.math.maxInt; /// Returns sqrt(x * x + y * y), avoiding unnecessary overflow and underflow. /// /// Special Cases: -/// - hypot(+-inf, y) = +inf -/// - hypot(x, +-inf) = +inf -/// - hypot(nan, y) = nan -/// - hypot(x, nan) = nan -pub fn hypot(comptime T: type, x: T, y: T) T { +/// +/// | x | y | hypot | +/// |-------|-------|-------| +/// | +inf | num | +inf | +/// | num | +-inf | +inf | +/// | nan | any | nan | +/// | any | nan | nan | +pub fn hypot(x: anytype, y: anytype) @TypeOf(x, y) { + const T = @TypeOf(x, y); return switch (T) { f32 => hypot32(x, y), f64 => hypot64(x, y), @@ -121,8 +125,12 @@ fn hypot64(x: f64, y: f64) f64 { } test "math.hypot" { - try expect(hypot(f32, 0.0, -1.2) == hypot32(0.0, -1.2)); - try expect(hypot(f64, 0.0, -1.2) == hypot64(0.0, -1.2)); + const x32: f32 = 0.0; + const y32: f32 = -1.2; + const x64: f64 = 0.0; + const y64: f64 = -1.2; + try expect(hypot(x32, y32) == hypot32(0.0, -1.2)); + try expect(hypot(x64, y64) == hypot64(0.0, -1.2)); } test "math.hypot32" {