From f67e9d709dedb552755a37d9996e3d3a90d23b52 Mon Sep 17 00:00:00 2001 From: adrien Date: Fri, 15 May 2026 00:24:39 +0200 Subject: [PATCH] Working add TensorAlloc --- src/TensorAlloc.zig | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/TensorAlloc.zig b/src/TensorAlloc.zig index 68b400d..8e8656f 100644 --- a/src/TensorAlloc.zig +++ b/src/TensorAlloc.zig @@ -107,8 +107,8 @@ pub fn TensorAlloc( defer area.deinit(); const TargetType = TensorAlloc(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); - const l: TargetType = try self.to(area.allocator(), TargetType); - const r: TargetType = try rhs.to(area.allocator(), TargetType); + const l = try self.to(area.allocator(), TargetType); + const r = try rhs.to(area.allocator(), TargetType); const new = try TargetType.init(alloc); new.data.* = if (comptime sh.isInt(T)) l.data.* +| r.data.* else l.data.* + r.data.*; @@ -252,14 +252,14 @@ pub fn TensorAlloc( const vec = try TensorAlloc(Dest.ValueType, dims.argsOpt(), scales.argsOpt(), Dest.shape).splat( alloc, - if (comptime T_info == .int and Dest_info == .int) - @as(DestT, @intCast(self.data[0])) + if (comptime Dest_info == .int) + @intCast(self.data[0]) else if (comptime T_info == .float and Dest_info == .float) - @as(DestT, @floatCast(self.data[0])) + @floatCast(self.data[0]) else if (comptime T_info == .int and Dest_info == .float) - @as(DestT, @floatFromInt(self.data[0])) + @floatFromInt(self.data[0]) else if (comptime T_info == .float and Dest_info == .int) - @as(DestT, @intFromFloat(self.data[0])) + @intFromFloat(self.data[0]) else unreachable, ); @@ -278,8 +278,10 @@ pub fn TensorAlloc( const new = try Dest.init(alloc); - if (comptime ratio == 1.0) + if (comptime ratio == 1.0) { + new.data.* = vec.data.*; return new; + } if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) { @@ -307,8 +309,8 @@ pub fn TensorAlloc( } // Cross-type fully vectorized casting with scales - const FVec = @Vector(total, f64); - const float_vec: FVec = switch (comptime @typeInfo(T)) { + const FVec = @Vector(total, DestT); + const float_vec: FVec = switch (comptime T_info) { .float => @floatCast(vec.data.*), .int => @floatFromInt(vec.data.*), else => unreachable,