From f0029449f069c9681bb366412a3680c67a484f61 Mon Sep 17 00:00:00 2001 From: adrien Date: Tue, 28 Apr 2026 14:50:08 +0200 Subject: [PATCH] tmp --- src/Tensor.zig | 51 +++++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/Tensor.zig b/src/Tensor.zig index a0b7e83..30d1de5 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -139,8 +139,9 @@ inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { /// Take the anyvalue coming from operation and if it is a Tensor, return it. /// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). -inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) { - const Rhs = if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r); +inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) { + const is_ptr = @typeInfo(@TypeOf(r)) == .pointer; + const Rhs = if (is_ptr) @TypeOf(r.*) else @TypeOf(r); if (comptime isTensor(Rhs)) return r; const scalar: T = switch (@typeInfo(Rhs)) { .comptime_int => switch (comptime @typeInfo(T)) { @@ -161,7 +162,7 @@ inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, if ( }, else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)), }; - return &Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; + return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; } pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { @@ -253,7 +254,7 @@ pub fn Tensor( return RhsTensorType(T, Rhs); } - inline fn rhs(r: anytype) *const RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) { + inline fn rhs(r: anytype) RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) { return toRhsTensor(T, r); } @@ -266,7 +267,7 @@ pub fn Tensor( /// Element-wise add. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). - pub inline fn add(self: *const Self, r: anytype) *const Tensor( + pub inline fn add(self: *const Self, r: anytype) Tensor( T, dims.argsOpt(), finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(), @@ -279,7 +280,7 @@ pub fn Tensor( if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return &.{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; + return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -288,12 +289,12 @@ pub fn Tensor( const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return &.{ .data = if (comptime isInt(T)) l +| rr else l + rr }; + return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; } /// Element-wise sub. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). - pub inline fn sub(self: *const Self, r: anytype) *const Tensor( + pub inline fn sub(self: *const Self, r: anytype) Tensor( T, dims.argsOpt(), finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(), @@ -306,7 +307,7 @@ pub fn Tensor( if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return &.{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; + return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -315,12 +316,12 @@ pub fn Tensor( const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return &.{ .data = if (comptime isInt(T)) l -| rr else l - rr }; + return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; } /// Element-wise multiply. Dimension exponents summed. /// Shape {1} RHS is automatically broadcast across all elements. - pub inline fn mul(self: *const Self, r: anytype) *const Tensor( + pub inline fn mul(self: *const Self, r: anytype) Tensor( T, dims.add(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(), finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(), @@ -336,12 +337,12 @@ pub fn Tensor( const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); - return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr }; + return .{ .data = if (comptime isInt(T)) l *| rr else l * rr }; } /// Element-wise divide. Dimension exponents subtracted. /// Shape {1} RHS is automatically broadcast across all elements. - pub inline fn div(self: *const Self, r: anytype) *const Tensor( + pub inline fn div(self: *const Self, r: anytype) Tensor( T, dims.sub(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(), finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(), @@ -358,19 +359,19 @@ pub fn Tensor( const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); if (comptime isInt(T)) { - return &.{ .data = @divTrunc(l, rr) }; + return .{ .data = @divTrunc(l, rr) }; } else { - return &.{ .data = l / rr }; + return .{ .data = l / rr }; } } /// Absolute value of every element. - pub inline fn abs(self: *const Self) *const Self { - return &.{ .data = @bitCast(@abs(self.data)) }; + pub inline fn abs(self: *const Self) Self { + return .{ .data = @bitCast(@abs(self.data)) }; } /// Raise every element to a comptime integer exponent. - pub inline fn pow(self: Self, comptime exp: comptime_int) *const Tensor( + pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor( T, dims.scale(exp).argsOpt(), scales.argsOpt(), @@ -396,11 +397,11 @@ pub fn Tensor( if (comptime !isInt(T) and exp < 0) { result = @as(Vec, @splat(1)) / result; } - return &.{ .data = result }; + return .{ .data = result }; } /// Square root of every element. All dimension exponents must be even. - pub inline fn sqrt(self: Self) *const Tensor( + pub inline fn sqrt(self: Self) Tensor( T, dims.div(2).argsOpt(), scales.argsOpt(), @@ -409,7 +410,7 @@ pub fn Tensor( if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { - return &.{ .data = @sqrt(self.data) }; // Float is natively vectorized! + return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! } else { const arr: [total]T = self.data; // Add this! var res_arr: [total]T = undefined; @@ -418,13 +419,13 @@ pub fn Tensor( const v = arr[i]; res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } - return &.{ .data = res_arr }; + return .{ .data = res_arr }; } } /// Negate every element. - pub inline fn negate(self: Self) *const Self { - return &.{ .data = -self.data }; + pub inline fn negate(self: Self) Self { + return .{ .data = -self.data }; } /// Convert to a compatible Tensor type. @@ -434,7 +435,7 @@ pub fn Tensor( pub inline fn to( self: *const Self, comptime Dest: type, - ) *const Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { + ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); if (comptime Self == ActualDest) return self;