Moved isTensor to shared + added isTensorAlloc/Static

This commit is contained in:
adrien 2026-05-15 00:32:58 +02:00
parent f67e9d709d
commit 00e0f5ab73
3 changed files with 28 additions and 36 deletions

View File

@ -6,17 +6,6 @@ const Dimension = Dimensions.Dimension;
const Allocator = std.mem.Allocator;
const sh = @import("shared.zig");
//
// File-scope RHS normalisation helpers
//
inline fn isTensor(comptime Rhs: type) bool {
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
}
/// SIMD implementation of a Tensor.
/// Limited to tensor of ~2000 values.
/// For more, see either TensorAlloc or TensorGPU
pub fn TensorAlloc(
comptime T: type,
comptime d_opt: Dimensions.ArgOpts,
@ -48,6 +37,7 @@ pub fn TensorAlloc(
pub const total: comptime_int = _total;
pub const strides_arr: [shape_.len]comptime_int = _strides;
pub const ISTENSOR = true;
pub const TENSORALLOC = true;
// Specific to Alloc
@ -96,7 +86,7 @@ pub fn TensorAlloc(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
@ -124,7 +114,7 @@ pub fn TensorAlloc(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
@ -146,7 +136,7 @@ pub fn TensorAlloc(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
@ -167,7 +157,7 @@ pub fn TensorAlloc(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
@ -336,7 +326,7 @@ pub fn TensorAlloc(
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: *Vec, r: *Vec } {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
@ -407,7 +397,7 @@ pub fn TensorAlloc(
comptime axis_b: usize,
) blk: {
const RhsType = @TypeOf(rhs);
if (!isTensor(RhsType))
if (!sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
@ -532,7 +522,7 @@ pub fn TensorAlloc(
) {
const RhsType = @TypeOf(rhs);
if (!isTensor(RhsType))
if (!sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");

View File

@ -5,17 +5,6 @@ const Dimensions = @import("Dimensions.zig");
const Dimension = Dimensions.Dimension;
const sh = @import("shared.zig");
//
// File-scope RHS normalisation helpers
//
inline fn isTensor(comptime Rhs: type) bool {
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
}
/// SIMD implementation of a Tensor.
/// Limited to tensor of ~2000 values.
/// For more, see either TensorAlloc or TensorGPU
pub fn TensorStatic(
comptime T: type,
comptime d_opt: Dimensions.ArgOpts,
@ -50,6 +39,7 @@ pub fn TensorStatic(
pub const total: comptime_int = _total;
pub const strides_arr: [shape_.len]comptime_int = _strides;
pub const ISTENSOR = true;
pub const TENSORSTATIC = true;
/// Convert N-D coords (row-major) to flat index fully comptime.
/// Usage: Tensor.idx(.{row, col})
@ -86,7 +76,7 @@ pub fn TensorStatic(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
@ -108,7 +98,7 @@ pub fn TensorStatic(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
@ -130,7 +120,7 @@ pub fn TensorStatic(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
@ -151,7 +141,7 @@ pub fn TensorStatic(
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
@ -362,7 +352,7 @@ pub fn TensorStatic(
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
const RhsType = @TypeOf(rhs);
if (comptime !isTensor(RhsType))
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
@ -433,7 +423,7 @@ pub fn TensorStatic(
comptime axis_b: usize,
) blk: {
const RhsType = @TypeOf(rhs);
if (!isTensor(RhsType))
if (!sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
@ -558,7 +548,7 @@ pub fn TensorStatic(
) {
const RhsType = @TypeOf(rhs);
if (!isTensor(RhsType))
if (!sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");

View File

@ -4,6 +4,18 @@ const UnitScale = Scales.UnitScale;
const Dimensions = @import("Dimensions.zig");
const Dimension = Dimensions.Dimension;
pub fn isTensor(comptime T: type) bool {
return comptime @typeInfo(T) == .@"struct" and @hasDecl(T, "ISTENSOR");
}
pub fn isTensorStatic(comptime T: type) bool {
return comptime isTensor(T) and @hasDecl(T, "TENSORSTATIC");
}
pub fn isTensorAlloc(comptime T: type) bool {
return comptime isTensor(T) and @hasDecl(T, "TENSORALLOC");
}
pub fn shapeTotal(shape: []const comptime_int) usize {
var t: comptime_int = 1;
for (shape) |s| t *= s;