Moved isTensor to shared + added isTensorAlloc/Static
This commit is contained in:
parent
f67e9d709d
commit
00e0f5ab73
@ -6,17 +6,6 @@ const Dimension = Dimensions.Dimension;
|
|||||||
const Allocator = std.mem.Allocator;
|
const Allocator = std.mem.Allocator;
|
||||||
const sh = @import("shared.zig");
|
const sh = @import("shared.zig");
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
// File-scope RHS normalisation helpers
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
inline fn isTensor(comptime Rhs: type) bool {
|
|
||||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// SIMD implementation of a Tensor.
|
|
||||||
/// Limited to tensor of ~2000 values.
|
|
||||||
/// For more, see either TensorAlloc or TensorGPU
|
|
||||||
pub fn TensorAlloc(
|
pub fn TensorAlloc(
|
||||||
comptime T: type,
|
comptime T: type,
|
||||||
comptime d_opt: Dimensions.ArgOpts,
|
comptime d_opt: Dimensions.ArgOpts,
|
||||||
@ -48,6 +37,7 @@ pub fn TensorAlloc(
|
|||||||
pub const total: comptime_int = _total;
|
pub const total: comptime_int = _total;
|
||||||
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||||
pub const ISTENSOR = true;
|
pub const ISTENSOR = true;
|
||||||
|
pub const TENSORALLOC = true;
|
||||||
|
|
||||||
// Specific to Alloc
|
// Specific to Alloc
|
||||||
|
|
||||||
@ -96,7 +86,7 @@ pub fn TensorAlloc(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
@ -124,7 +114,7 @@ pub fn TensorAlloc(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
@ -146,7 +136,7 @@ pub fn TensorAlloc(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -167,7 +157,7 @@ pub fn TensorAlloc(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -336,7 +326,7 @@ pub fn TensorAlloc(
|
|||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: *Vec, r: *Vec } {
|
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: *Vec, r: *Vec } {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -407,7 +397,7 @@ pub fn TensorAlloc(
|
|||||||
comptime axis_b: usize,
|
comptime axis_b: usize,
|
||||||
) blk: {
|
) blk: {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (!isTensor(RhsType))
|
if (!sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||||
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
||||||
@ -532,7 +522,7 @@ pub fn TensorAlloc(
|
|||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
|
|
||||||
if (!isTensor(RhsType))
|
if (!sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
||||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||||
|
|||||||
@ -5,17 +5,6 @@ const Dimensions = @import("Dimensions.zig");
|
|||||||
const Dimension = Dimensions.Dimension;
|
const Dimension = Dimensions.Dimension;
|
||||||
const sh = @import("shared.zig");
|
const sh = @import("shared.zig");
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
// File-scope RHS normalisation helpers
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
inline fn isTensor(comptime Rhs: type) bool {
|
|
||||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// SIMD implementation of a Tensor.
|
|
||||||
/// Limited to tensor of ~2000 values.
|
|
||||||
/// For more, see either TensorAlloc or TensorGPU
|
|
||||||
pub fn TensorStatic(
|
pub fn TensorStatic(
|
||||||
comptime T: type,
|
comptime T: type,
|
||||||
comptime d_opt: Dimensions.ArgOpts,
|
comptime d_opt: Dimensions.ArgOpts,
|
||||||
@ -50,6 +39,7 @@ pub fn TensorStatic(
|
|||||||
pub const total: comptime_int = _total;
|
pub const total: comptime_int = _total;
|
||||||
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||||
pub const ISTENSOR = true;
|
pub const ISTENSOR = true;
|
||||||
|
pub const TENSORSTATIC = true;
|
||||||
|
|
||||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||||
/// Usage: Tensor.idx(.{row, col})
|
/// Usage: Tensor.idx(.{row, col})
|
||||||
@ -86,7 +76,7 @@ pub fn TensorStatic(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
@ -108,7 +98,7 @@ pub fn TensorStatic(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
@ -130,7 +120,7 @@ pub fn TensorStatic(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -151,7 +141,7 @@ pub fn TensorStatic(
|
|||||||
shape,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -362,7 +352,7 @@ pub fn TensorStatic(
|
|||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime !isTensor(RhsType))
|
if (comptime !sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
@ -433,7 +423,7 @@ pub fn TensorStatic(
|
|||||||
comptime axis_b: usize,
|
comptime axis_b: usize,
|
||||||
) blk: {
|
) blk: {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (!isTensor(RhsType))
|
if (!sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||||
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
||||||
@ -558,7 +548,7 @@ pub fn TensorStatic(
|
|||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
|
|
||||||
if (!isTensor(RhsType))
|
if (!sh.isTensor(RhsType))
|
||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
||||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||||
|
|||||||
@ -4,6 +4,18 @@ const UnitScale = Scales.UnitScale;
|
|||||||
const Dimensions = @import("Dimensions.zig");
|
const Dimensions = @import("Dimensions.zig");
|
||||||
const Dimension = Dimensions.Dimension;
|
const Dimension = Dimensions.Dimension;
|
||||||
|
|
||||||
|
pub fn isTensor(comptime T: type) bool {
|
||||||
|
return comptime @typeInfo(T) == .@"struct" and @hasDecl(T, "ISTENSOR");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isTensorStatic(comptime T: type) bool {
|
||||||
|
return comptime isTensor(T) and @hasDecl(T, "TENSORSTATIC");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isTensorAlloc(comptime T: type) bool {
|
||||||
|
return comptime isTensor(T) and @hasDecl(T, "TENSORALLOC");
|
||||||
|
}
|
||||||
|
|
||||||
pub fn shapeTotal(shape: []const comptime_int) usize {
|
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||||
var t: comptime_int = 1;
|
var t: comptime_int = 1;
|
||||||
for (shape) |s| t *= s;
|
for (shape) |s| t *= s;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user