Moved isTensor to shared + added isTensorAlloc/Static
This commit is contained in:
parent
f67e9d709d
commit
00e0f5ab73
@ -6,17 +6,6 @@ const Dimension = Dimensions.Dimension;
|
||||
const Allocator = std.mem.Allocator;
|
||||
const sh = @import("shared.zig");
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// File-scope RHS normalisation helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
inline fn isTensor(comptime Rhs: type) bool {
|
||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||
}
|
||||
|
||||
/// SIMD implementation of a Tensor.
|
||||
/// Limited to tensor of ~2000 values.
|
||||
/// For more, see either TensorAlloc or TensorGPU
|
||||
pub fn TensorAlloc(
|
||||
comptime T: type,
|
||||
comptime d_opt: Dimensions.ArgOpts,
|
||||
@ -48,6 +37,7 @@ pub fn TensorAlloc(
|
||||
pub const total: comptime_int = _total;
|
||||
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||
pub const ISTENSOR = true;
|
||||
pub const TENSORALLOC = true;
|
||||
|
||||
// Specific to Alloc
|
||||
|
||||
@ -96,7 +86,7 @@ pub fn TensorAlloc(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
@ -124,7 +114,7 @@ pub fn TensorAlloc(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
@ -146,7 +136,7 @@ pub fn TensorAlloc(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -167,7 +157,7 @@ pub fn TensorAlloc(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -336,7 +326,7 @@ pub fn TensorAlloc(
|
||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: *Vec, r: *Vec } {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -407,7 +397,7 @@ pub fn TensorAlloc(
|
||||
comptime axis_b: usize,
|
||||
) blk: {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (!isTensor(RhsType))
|
||||
if (!sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
||||
@ -532,7 +522,7 @@ pub fn TensorAlloc(
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
|
||||
if (!isTensor(RhsType))
|
||||
if (!sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||
|
||||
@ -5,17 +5,6 @@ const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = Dimensions.Dimension;
|
||||
const sh = @import("shared.zig");
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// File-scope RHS normalisation helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
inline fn isTensor(comptime Rhs: type) bool {
|
||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||
}
|
||||
|
||||
/// SIMD implementation of a Tensor.
|
||||
/// Limited to tensor of ~2000 values.
|
||||
/// For more, see either TensorAlloc or TensorGPU
|
||||
pub fn TensorStatic(
|
||||
comptime T: type,
|
||||
comptime d_opt: Dimensions.ArgOpts,
|
||||
@ -50,6 +39,7 @@ pub fn TensorStatic(
|
||||
pub const total: comptime_int = _total;
|
||||
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||
pub const ISTENSOR = true;
|
||||
pub const TENSORSTATIC = true;
|
||||
|
||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||
/// Usage: Tensor.idx(.{row, col})
|
||||
@ -86,7 +76,7 @@ pub fn TensorStatic(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
@ -108,7 +98,7 @@ pub fn TensorStatic(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
@ -130,7 +120,7 @@ pub fn TensorStatic(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -151,7 +141,7 @@ pub fn TensorStatic(
|
||||
shape,
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -362,7 +352,7 @@ pub fn TensorStatic(
|
||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
if (comptime !sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
@ -433,7 +423,7 @@ pub fn TensorStatic(
|
||||
comptime axis_b: usize,
|
||||
) blk: {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (!isTensor(RhsType))
|
||||
if (!sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
||||
@ -558,7 +548,7 @@ pub fn TensorStatic(
|
||||
) {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
|
||||
if (!isTensor(RhsType))
|
||||
if (!sh.isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||
|
||||
@ -4,6 +4,18 @@ const UnitScale = Scales.UnitScale;
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = Dimensions.Dimension;
|
||||
|
||||
pub fn isTensor(comptime T: type) bool {
|
||||
return comptime @typeInfo(T) == .@"struct" and @hasDecl(T, "ISTENSOR");
|
||||
}
|
||||
|
||||
pub fn isTensorStatic(comptime T: type) bool {
|
||||
return comptime isTensor(T) and @hasDecl(T, "TENSORSTATIC");
|
||||
}
|
||||
|
||||
pub fn isTensorAlloc(comptime T: type) bool {
|
||||
return comptime isTensor(T) and @hasDecl(T, "TENSORALLOC");
|
||||
}
|
||||
|
||||
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||
var t: comptime_int = 1;
|
||||
for (shape) |s| t *= s;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user