From 00e0f5ab7326a2f9fa8c0c2ec9ea96b07608b97a Mon Sep 17 00:00:00 2001 From: adrien Date: Fri, 15 May 2026 00:32:58 +0200 Subject: [PATCH] Moved isTensor to shared + added isTensorAlloc/Static --- src/TensorAlloc.zig | 26 ++++++++------------------ src/TensorStatic.zig | 26 ++++++++------------------ src/shared.zig | 12 ++++++++++++ 3 files changed, 28 insertions(+), 36 deletions(-) diff --git a/src/TensorAlloc.zig b/src/TensorAlloc.zig index 8e8656f..cf38c98 100644 --- a/src/TensorAlloc.zig +++ b/src/TensorAlloc.zig @@ -6,17 +6,6 @@ const Dimension = Dimensions.Dimension; const Allocator = std.mem.Allocator; const sh = @import("shared.zig"); -// ───────────────────────────────────────────────────────────────────────────── -// File-scope RHS normalisation helpers -// ───────────────────────────────────────────────────────────────────────────── - -inline fn isTensor(comptime Rhs: type) bool { - return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); -} - -/// SIMD implementation of a Tensor. -/// Limited to tensor of ~2000 values. -/// For more, see either TensorAlloc or TensorGPU pub fn TensorAlloc( comptime T: type, comptime d_opt: Dimensions.ArgOpts, @@ -48,6 +37,7 @@ pub fn TensorAlloc( pub const total: comptime_int = _total; pub const strides_arr: [shape_.len]comptime_int = _strides; pub const ISTENSOR = true; + pub const TENSORALLOC = true; // Specific to Alloc @@ -96,7 +86,7 @@ pub fn TensorAlloc( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @@ -124,7 +114,7 @@ pub fn TensorAlloc( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @@ -146,7 +136,7 @@ pub fn TensorAlloc( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); @@ -167,7 +157,7 @@ pub fn TensorAlloc( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); @@ -336,7 +326,7 @@ pub fn TensorAlloc( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: *Vec, r: *Vec } { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); @@ -407,7 +397,7 @@ pub fn TensorAlloc( comptime axis_b: usize, ) blk: { const RhsType = @TypeOf(rhs); - if (!isTensor(RhsType)) + if (!sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (axis_a >= rank) @compileError("contract: axis_a out of bounds"); if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds"); @@ -532,7 +522,7 @@ pub fn TensorAlloc( ) { const RhsType = @TypeOf(rhs); - if (!isTensor(RhsType)) + if (!sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) @compileError("cross product is only defined for 3D vectors (rank-1, length 3)"); diff --git a/src/TensorStatic.zig b/src/TensorStatic.zig index 1e02ff7..a250411 100644 --- a/src/TensorStatic.zig +++ b/src/TensorStatic.zig @@ -5,17 +5,6 @@ const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; const sh = @import("shared.zig"); -// ───────────────────────────────────────────────────────────────────────────── -// File-scope RHS normalisation helpers -// ───────────────────────────────────────────────────────────────────────────── - -inline fn isTensor(comptime Rhs: type) bool { - return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); -} - -/// SIMD implementation of a Tensor. -/// Limited to tensor of ~2000 values. -/// For more, see either TensorAlloc or TensorGPU pub fn TensorStatic( comptime T: type, comptime d_opt: Dimensions.ArgOpts, @@ -50,6 +39,7 @@ pub fn TensorStatic( pub const total: comptime_int = _total; pub const strides_arr: [shape_.len]comptime_int = _strides; pub const ISTENSOR = true; + pub const TENSORSTATIC = true; /// Convert N-D coords (row-major) to flat index — fully comptime. /// Usage: Tensor.idx(.{row, col}) @@ -86,7 +76,7 @@ pub fn TensorStatic( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @@ -108,7 +98,7 @@ pub fn TensorStatic( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @@ -130,7 +120,7 @@ pub fn TensorStatic( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); @@ -151,7 +141,7 @@ pub fn TensorStatic( shape, ) { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); @@ -362,7 +352,7 @@ pub fn TensorStatic( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } { const RhsType = @TypeOf(rhs); - if (comptime !isTensor(RhsType)) + if (comptime !sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); @@ -433,7 +423,7 @@ pub fn TensorStatic( comptime axis_b: usize, ) blk: { const RhsType = @TypeOf(rhs); - if (!isTensor(RhsType)) + if (!sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (axis_a >= rank) @compileError("contract: axis_a out of bounds"); if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds"); @@ -558,7 +548,7 @@ pub fn TensorStatic( ) { const RhsType = @TypeOf(rhs); - if (!isTensor(RhsType)) + if (!sh.isTensor(RhsType)) @compileError("rhs can only be a Tensor "); if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) @compileError("cross product is only defined for 3D vectors (rank-1, length 3)"); diff --git a/src/shared.zig b/src/shared.zig index 137011f..45ede7f 100644 --- a/src/shared.zig +++ b/src/shared.zig @@ -4,6 +4,18 @@ const UnitScale = Scales.UnitScale; const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; +pub fn isTensor(comptime T: type) bool { + return comptime @typeInfo(T) == .@"struct" and @hasDecl(T, "ISTENSOR"); +} + +pub fn isTensorStatic(comptime T: type) bool { + return comptime isTensor(T) and @hasDecl(T, "TENSORSTATIC"); +} + +pub fn isTensorAlloc(comptime T: type) bool { + return comptime isTensor(T) and @hasDecl(T, "TENSORALLOC"); +} + pub fn shapeTotal(shape: []const comptime_int) usize { var t: comptime_int = 1; for (shape) |s| t *= s;