Working add TensorAlloc
This commit is contained in:
parent
e6d0f62929
commit
f67e9d709d
@ -107,8 +107,8 @@ pub fn TensorAlloc(
|
||||
defer area.deinit();
|
||||
|
||||
const TargetType = TensorAlloc(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const l: TargetType = try self.to(area.allocator(), TargetType);
|
||||
const r: TargetType = try rhs.to(area.allocator(), TargetType);
|
||||
const l = try self.to(area.allocator(), TargetType);
|
||||
const r = try rhs.to(area.allocator(), TargetType);
|
||||
|
||||
const new = try TargetType.init(alloc);
|
||||
new.data.* = if (comptime sh.isInt(T)) l.data.* +| r.data.* else l.data.* + r.data.*;
|
||||
@ -252,14 +252,14 @@ pub fn TensorAlloc(
|
||||
|
||||
const vec = try TensorAlloc(Dest.ValueType, dims.argsOpt(), scales.argsOpt(), Dest.shape).splat(
|
||||
alloc,
|
||||
if (comptime T_info == .int and Dest_info == .int)
|
||||
@as(DestT, @intCast(self.data[0]))
|
||||
if (comptime Dest_info == .int)
|
||||
@intCast(self.data[0])
|
||||
else if (comptime T_info == .float and Dest_info == .float)
|
||||
@as(DestT, @floatCast(self.data[0]))
|
||||
@floatCast(self.data[0])
|
||||
else if (comptime T_info == .int and Dest_info == .float)
|
||||
@as(DestT, @floatFromInt(self.data[0]))
|
||||
@floatFromInt(self.data[0])
|
||||
else if (comptime T_info == .float and Dest_info == .int)
|
||||
@as(DestT, @intFromFloat(self.data[0]))
|
||||
@intFromFloat(self.data[0])
|
||||
else
|
||||
unreachable,
|
||||
);
|
||||
@ -278,8 +278,10 @@ pub fn TensorAlloc(
|
||||
|
||||
const new = try Dest.init(alloc);
|
||||
|
||||
if (comptime ratio == 1.0)
|
||||
if (comptime ratio == 1.0) {
|
||||
new.data.* = vec.data.*;
|
||||
return new;
|
||||
}
|
||||
|
||||
if (comptime T == DestT) {
|
||||
if (comptime @typeInfo(T) == .float) {
|
||||
@ -307,8 +309,8 @@ pub fn TensorAlloc(
|
||||
}
|
||||
|
||||
// Cross-type fully vectorized casting with scales
|
||||
const FVec = @Vector(total, f64);
|
||||
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
||||
const FVec = @Vector(total, DestT);
|
||||
const float_vec: FVec = switch (comptime T_info) {
|
||||
.float => @floatCast(vec.data.*),
|
||||
.int => @floatFromInt(vec.data.*),
|
||||
else => unreachable,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user