Simplified pow

This commit is contained in:
adrien 2026-05-04 22:57:53 +02:00
parent 18830c8b45
commit 5bdc78c065

View File

@ -175,27 +175,13 @@ pub fn TensorStatic(
scales.argsOpt(), scales.argsOpt(),
shape, shape,
) { ) {
if (comptime exp < 0) @compileError("Pow only support exp >= 0");
if (comptime exp == 0) return .{ .data = @splat(1) }; if (comptime exp == 0) return .{ .data = @splat(1) };
if (comptime exp == 1) return self; if (comptime exp == 1) return self;
var data: Vec = self.data;
var base = self.data; for (0..exp - 1) |_|
var result: Vec = @splat(1); data = data * self.data;
comptime var e = @abs(exp); return .{ .data = data };
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
inline while (e > 0) {
if (e % 2 == 1) {
result = if (comptime sh.isInt(T)) result *| base else result * base;
}
e /= 2;
if (e > 0) {
base = if (comptime sh.isInt(T)) base *| base else base * base;
}
}
if (comptime !sh.isInt(T) and exp < 0) {
result = @as(Vec, @splat(1)) / result;
}
return .{ .data = result };
} }
/// Square root of every element. All dimension exponents must be even. /// Square root of every element. All dimension exponents must be even.
@ -207,18 +193,17 @@ pub fn TensorStatic(
) { ) {
if (comptime !dims.isSquare()) if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
if (comptime @typeInfo(T) == .float) { if (comptime @typeInfo(T) == .float)
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! return .{ .data = @sqrt(self.data) };
} else {
const arr: [total]T = self.data; // Add this! const arr: [total]T = self.data; // Add this!
var res_arr: [total]T = undefined; var res_arr: [total]T = undefined;
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
for (0..total) |i| { for (0..total) |i| {
const v = arr[i]; const v = arr[i];
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
}
return .{ .data = res_arr };
} }
return .{ .data = res_arr };
} }
/// Negate every element. /// Negate every element.