Working add TensorAlloc

This commit is contained in:
adrien 2026-05-15 00:24:39 +02:00
parent e6d0f62929
commit f67e9d709d

View File

@ -107,8 +107,8 @@ pub fn TensorAlloc(
defer area.deinit(); defer area.deinit();
const TargetType = TensorAlloc(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); const TargetType = TensorAlloc(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
const l: TargetType = try self.to(area.allocator(), TargetType); const l = try self.to(area.allocator(), TargetType);
const r: TargetType = try rhs.to(area.allocator(), TargetType); const r = try rhs.to(area.allocator(), TargetType);
const new = try TargetType.init(alloc); const new = try TargetType.init(alloc);
new.data.* = if (comptime sh.isInt(T)) l.data.* +| r.data.* else l.data.* + r.data.*; new.data.* = if (comptime sh.isInt(T)) l.data.* +| r.data.* else l.data.* + r.data.*;
@ -252,14 +252,14 @@ pub fn TensorAlloc(
const vec = try TensorAlloc(Dest.ValueType, dims.argsOpt(), scales.argsOpt(), Dest.shape).splat( const vec = try TensorAlloc(Dest.ValueType, dims.argsOpt(), scales.argsOpt(), Dest.shape).splat(
alloc, alloc,
if (comptime T_info == .int and Dest_info == .int) if (comptime Dest_info == .int)
@as(DestT, @intCast(self.data[0])) @intCast(self.data[0])
else if (comptime T_info == .float and Dest_info == .float) else if (comptime T_info == .float and Dest_info == .float)
@as(DestT, @floatCast(self.data[0])) @floatCast(self.data[0])
else if (comptime T_info == .int and Dest_info == .float) else if (comptime T_info == .int and Dest_info == .float)
@as(DestT, @floatFromInt(self.data[0])) @floatFromInt(self.data[0])
else if (comptime T_info == .float and Dest_info == .int) else if (comptime T_info == .float and Dest_info == .int)
@as(DestT, @intFromFloat(self.data[0])) @intFromFloat(self.data[0])
else else
unreachable, unreachable,
); );
@ -278,8 +278,10 @@ pub fn TensorAlloc(
const new = try Dest.init(alloc); const new = try Dest.init(alloc);
if (comptime ratio == 1.0) if (comptime ratio == 1.0) {
new.data.* = vec.data.*;
return new; return new;
}
if (comptime T == DestT) { if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) { if (comptime @typeInfo(T) == .float) {
@ -307,8 +309,8 @@ pub fn TensorAlloc(
} }
// Cross-type fully vectorized casting with scales // Cross-type fully vectorized casting with scales
const FVec = @Vector(total, f64); const FVec = @Vector(total, DestT);
const float_vec: FVec = switch (comptime @typeInfo(T)) { const float_vec: FVec = switch (comptime T_info) {
.float => @floatCast(vec.data.*), .float => @floatCast(vec.data.*),
.int => @floatFromInt(vec.data.*), .int => @floatFromInt(vec.data.*),
else => unreachable, else => unreachable,