diff --git a/src/Tensor.zig b/src/Tensor.zig index 85be2b1..397dfec 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -175,27 +175,13 @@ pub fn TensorStatic( scales.argsOpt(), shape, ) { + if (comptime exp < 0) @compileError("Pow only support exp >= 0"); if (comptime exp == 0) return .{ .data = @splat(1) }; if (comptime exp == 1) return self; - - var base = self.data; - var result: Vec = @splat(1); - comptime var e = @abs(exp); - - // $O(\log n)$ Exponentiation by squaring applied to the entire vector - inline while (e > 0) { - if (e % 2 == 1) { - result = if (comptime sh.isInt(T)) result *| base else result * base; - } - e /= 2; - if (e > 0) { - base = if (comptime sh.isInt(T)) base *| base else base * base; - } - } - if (comptime !sh.isInt(T) and exp < 0) { - result = @as(Vec, @splat(1)) / result; - } - return .{ .data = result }; + var data: Vec = self.data; + for (0..exp - 1) |_| + data = data * self.data; + return .{ .data = data }; } /// Square root of every element. All dimension exponents must be even. @@ -207,18 +193,17 @@ pub fn TensorStatic( ) { if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); - if (comptime @typeInfo(T) == .float) { - return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! - } else { - const arr: [total]T = self.data; // Add this! - var res_arr: [total]T = undefined; - const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - for (0..total) |i| { - const v = arr[i]; - res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); - } - return .{ .data = res_arr }; + if (comptime @typeInfo(T) == .float) + return .{ .data = @sqrt(self.data) }; + + const arr: [total]T = self.data; // Add this! + var res_arr: [total]T = undefined; + const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); + for (0..total) |i| { + const v = arr[i]; + res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } + return .{ .data = res_arr }; } /// Negate every element.