Simplified pow
This commit is contained in:
parent
18830c8b45
commit
5bdc78c065
@ -175,27 +175,13 @@ pub fn TensorStatic(
|
||||
scales.argsOpt(),
|
||||
shape,
|
||||
) {
|
||||
if (comptime exp < 0) @compileError("Pow only support exp >= 0");
|
||||
if (comptime exp == 0) return .{ .data = @splat(1) };
|
||||
if (comptime exp == 1) return self;
|
||||
|
||||
var base = self.data;
|
||||
var result: Vec = @splat(1);
|
||||
comptime var e = @abs(exp);
|
||||
|
||||
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
|
||||
inline while (e > 0) {
|
||||
if (e % 2 == 1) {
|
||||
result = if (comptime sh.isInt(T)) result *| base else result * base;
|
||||
}
|
||||
e /= 2;
|
||||
if (e > 0) {
|
||||
base = if (comptime sh.isInt(T)) base *| base else base * base;
|
||||
}
|
||||
}
|
||||
if (comptime !sh.isInt(T) and exp < 0) {
|
||||
result = @as(Vec, @splat(1)) / result;
|
||||
}
|
||||
return .{ .data = result };
|
||||
var data: Vec = self.data;
|
||||
for (0..exp - 1) |_|
|
||||
data = data * self.data;
|
||||
return .{ .data = data };
|
||||
}
|
||||
|
||||
/// Square root of every element. All dimension exponents must be even.
|
||||
@ -207,9 +193,9 @@ pub fn TensorStatic(
|
||||
) {
|
||||
if (comptime !dims.isSquare())
|
||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||
if (comptime @typeInfo(T) == .float) {
|
||||
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
||||
} else {
|
||||
if (comptime @typeInfo(T) == .float)
|
||||
return .{ .data = @sqrt(self.data) };
|
||||
|
||||
const arr: [total]T = self.data; // Add this!
|
||||
var res_arr: [total]T = undefined;
|
||||
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
||||
@ -219,7 +205,6 @@ pub fn TensorStatic(
|
||||
}
|
||||
return .{ .data = res_arr };
|
||||
}
|
||||
}
|
||||
|
||||
/// Negate every element.
|
||||
pub inline fn negate(self: *const Self) Self {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user