Simplified pow

This commit is contained in:
adrien 2026-05-04 22:57:53 +02:00
parent 18830c8b45
commit 5bdc78c065

View File

@ -175,27 +175,13 @@ pub fn TensorStatic(
scales.argsOpt(),
shape,
) {
if (comptime exp < 0) @compileError("Pow only support exp >= 0");
if (comptime exp == 0) return .{ .data = @splat(1) };
if (comptime exp == 1) return self;
var base = self.data;
var result: Vec = @splat(1);
comptime var e = @abs(exp);
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
inline while (e > 0) {
if (e % 2 == 1) {
result = if (comptime sh.isInt(T)) result *| base else result * base;
}
e /= 2;
if (e > 0) {
base = if (comptime sh.isInt(T)) base *| base else base * base;
}
}
if (comptime !sh.isInt(T) and exp < 0) {
result = @as(Vec, @splat(1)) / result;
}
return .{ .data = result };
var data: Vec = self.data;
for (0..exp - 1) |_|
data = data * self.data;
return .{ .data = data };
}
/// Square root of every element. All dimension exponents must be even.
@ -207,9 +193,9 @@ pub fn TensorStatic(
) {
if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
if (comptime @typeInfo(T) == .float) {
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
} else {
if (comptime @typeInfo(T) == .float)
return .{ .data = @sqrt(self.data) };
const arr: [total]T = self.data; // Add this!
var res_arr: [total]T = undefined;
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
@ -219,7 +205,6 @@ pub fn TensorStatic(
}
return .{ .data = res_arr };
}
}
/// Negate every element.
pub inline fn negate(self: *const Self) Self {