Simplified pow
This commit is contained in:
parent
18830c8b45
commit
5bdc78c065
@ -175,27 +175,13 @@ pub fn TensorStatic(
|
|||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
|
if (comptime exp < 0) @compileError("Pow only support exp >= 0");
|
||||||
if (comptime exp == 0) return .{ .data = @splat(1) };
|
if (comptime exp == 0) return .{ .data = @splat(1) };
|
||||||
if (comptime exp == 1) return self;
|
if (comptime exp == 1) return self;
|
||||||
|
var data: Vec = self.data;
|
||||||
var base = self.data;
|
for (0..exp - 1) |_|
|
||||||
var result: Vec = @splat(1);
|
data = data * self.data;
|
||||||
comptime var e = @abs(exp);
|
return .{ .data = data };
|
||||||
|
|
||||||
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
|
|
||||||
inline while (e > 0) {
|
|
||||||
if (e % 2 == 1) {
|
|
||||||
result = if (comptime sh.isInt(T)) result *| base else result * base;
|
|
||||||
}
|
|
||||||
e /= 2;
|
|
||||||
if (e > 0) {
|
|
||||||
base = if (comptime sh.isInt(T)) base *| base else base * base;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (comptime !sh.isInt(T) and exp < 0) {
|
|
||||||
result = @as(Vec, @splat(1)) / result;
|
|
||||||
}
|
|
||||||
return .{ .data = result };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
@ -207,9 +193,9 @@ pub fn TensorStatic(
|
|||||||
) {
|
) {
|
||||||
if (comptime !dims.isSquare())
|
if (comptime !dims.isSquare())
|
||||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||||
if (comptime @typeInfo(T) == .float) {
|
if (comptime @typeInfo(T) == .float)
|
||||||
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
return .{ .data = @sqrt(self.data) };
|
||||||
} else {
|
|
||||||
const arr: [total]T = self.data; // Add this!
|
const arr: [total]T = self.data; // Add this!
|
||||||
var res_arr: [total]T = undefined;
|
var res_arr: [total]T = undefined;
|
||||||
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
||||||
@ -219,7 +205,6 @@ pub fn TensorStatic(
|
|||||||
}
|
}
|
||||||
return .{ .data = res_arr };
|
return .{ .data = res_arr };
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Negate every element.
|
/// Negate every element.
|
||||||
pub inline fn negate(self: *const Self) Self {
|
pub inline fn negate(self: *const Self) Self {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user