Added shape comptime check for Tensor add/sub/div/mul

This commit is contained in:
AdrienBouvais 2026-04-27 14:45:41 +02:00
parent f37a196b15
commit 16d25e7e7e

View File

@ -14,6 +14,15 @@ pub fn shapeTotal(comptime shape: []const usize) usize {
return t;
}
/// Check if two shapes are strictly identical (same rank, same extent on
/// every axis). Fully usable at comptime; both slices must be comptime-known.
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
    // std.mem.eql already performs the length check before the element scan.
    return std.mem.eql(usize, a, b);
}
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
/// e.g. shape {3, 4} → strides {4, 1}
/// shape {2, 3, 4} → strides {12, 4, 1}
@ -126,10 +135,6 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
}
//
// File-scope RHS normalisation helpers
//
// Any bare comptime_int / comptime_float / runtime T used as an arithmetic
// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}.
// Actual Tensor types are passed through unchanged.
//
fn isTensor(comptime Rhs: type) bool {
@ -193,32 +198,6 @@ pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
}
}
///
/// Tensor unified dimensioned ND type.
///
/// T : element numeric type (f32, f64, i32, i128, …)
/// d_opt : SI dimension exponents
/// s_opt : unit scales
/// shape_ : compile-time shape
/// &.{1} → scalar
/// &.{3} → 3-vector
/// &.{4, 4} → 4×4 matrix
/// &.{3, 3, 3} → 3D field
///
/// Storage: flat @Vector(total, T) where total = product(shape_).
/// All arithmetic operates on the flat vector directly SIMD wherever possible.
///
/// Shape-related comptime constants exposed on every Tensor type:
/// dims : Dimensions SI exponent struct
/// scales : Scales unit scale struct
/// shape : []const usize
/// rank : usize = shape.len
/// total : usize = product(shape)
/// strides_arr : [rank]usize row-major strides
///
/// Index helper:
/// Tensor.idx(.{row, col}) flat index (comptime, no runtime cost)
///
pub fn Tensor(
comptime T: type,
comptime d_opt: Dimensions.ArgOpts,
@ -226,8 +205,10 @@ pub fn Tensor(
comptime shape_: []const usize,
) type {
comptime {
std.debug.assert(shape_.len >= 1);
for (shape_) |s| std.debug.assert(s >= 1);
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
for (shape_) |s| {
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
}
}
@setEvalBranchQuota(10_000_000);
@ -236,7 +217,6 @@ pub fn Tensor(
const Vec = @Vector(_total, T);
return struct {
/// Flat SIMD storage. All arithmetic operates here directly.
data: Vec,
const Self = @This();
@ -250,27 +230,19 @@ pub fn Tensor(
pub const strides_arr: [shape_.len]usize = _strides;
pub const ISTENSOR = true;
//
// Index helper
//
/// Convert N-D coords (row-major) to a flat index — fully comptime, no
/// runtime cost. Out-of-bounds coordinates are a compile error.
/// Usage: Tensor.idx(.{row, col})
pub inline fn idx(comptime coords: [rank]usize) usize {
    comptime {
        var flat: usize = 0;
        for (0..rank) |i| {
            // Comptime bounds check: a bad coordinate fails the build rather
            // than asserting at runtime. (Removed the redundant
            // std.debug.assert that duplicated this exact check.)
            if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds");
            flat += coords[i] * strides_arr[i];
        }
        return flat;
    }
}
//
// Constructors
//
/// Broadcast a single value across all elements.
pub inline fn splat(v: T) Self {
return .{ .data = @splat(v) };
@ -279,19 +251,11 @@ pub fn Tensor(
pub const zero: Self = splat(0);
pub const one: Self = splat(1);
//
// GPU readiness
//
/// Return a mutable slice over the flat storage — zero-copy, e.g. for
/// WebGPU buffer mapping. The slice aliases self.data; it is valid only
/// while self is alive and unmoved.
pub inline fn asSlice(self: *Self) []T {
    const base: [*]T = @ptrCast(&self.data);
    return base[0..total];
}
//
// Internal: RHS normalisation
//
// Resolve the normalised Tensor type for an arbitrary RHS type — forwards
// to the file-scope RhsTensorType helper with this Tensor's element type T.
// NOTE(review): the exact wrapping rules (bare comptime_int/float vs Tensor
// pass-through) live in RhsTensorType; this only delegates — confirm there.
inline fn RhsT(comptime Rhs: type) type {
return RhsTensorType(T, Rhs);
}
@ -299,10 +263,6 @@ pub fn Tensor(
return toRhsTensor(T, r);
}
//
// Internal: scalar broadcast (shape {1} full Vec)
//
inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
return if (comptime RhsType.total == 1 and total > 1)
@splat(r.data[0])
@ -310,12 +270,8 @@ pub fn Tensor(
r.data;
}
//
// Arithmetic
//
/// Element-wise add. Dimensions must match; scales resolve to finer.
/// RHS must have the same element count as self, or total == 1 (broadcast).
/// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn add(self: Self, r: anytype) Tensor(
T,
dims.argsOpt(),
@ -326,8 +282,8 @@ pub fn Tensor(
const RhsType = @TypeOf(rhs_q);
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != total and RhsType.total != 1)
@compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1).");
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
@ -339,7 +295,8 @@ pub fn Tensor(
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
}
/// Element-wise subtract. Dimensions must match; scales resolve to finer.
/// Element-wise sub. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn sub(self: Self, r: anytype) Tensor(
T,
dims.argsOpt(),
@ -350,8 +307,8 @@ pub fn Tensor(
const RhsType = @TypeOf(rhs_q);
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != total and RhsType.total != 1)
@compileError("Shape mismatch in sub.");
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
@ -373,8 +330,8 @@ pub fn Tensor(
) {
const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q);
if (comptime RhsType.total != total and RhsType.total != 1)
@compileError("Shape mismatch in mul.");
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
@ -394,8 +351,8 @@ pub fn Tensor(
) {
const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q);
if (comptime RhsType.total != total and RhsType.total != 1)
@compileError("Shape mismatch in div.");
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
@ -411,10 +368,6 @@ pub fn Tensor(
}
}
//
// Unary
//
/// Absolute value of every element.
pub inline fn abs(self: Self) Self {
return .{ .data = @bitCast(@abs(self.data)) };
@ -469,13 +422,9 @@ pub fn Tensor(
return .{ .data = -self.data };
}
//
// Conversion
//
/// Convert to a compatible Tensor type.
/// Dimension mismatch compile error.
/// Dest.total must equal self.total, or Dest.total == 1 (scalar pattern).
/// Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern).
/// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
pub inline fn to(
self: Self,
@ -485,20 +434,19 @@ pub fn Tensor(
if (comptime !dims.eql(ActualDest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
if (comptime Self == ActualDest) return self;
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
comptime std.debug.assert(Dest.total == total or Dest.total == 1);
if (comptime Self == ActualDest) return self;
const DestT = ActualDest.ValueType;
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
const DestVec = @Vector(total, DestT);
// Same numeric type
if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float)
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
// Integer branch prevents division-by-zero
if (comptime ratio >= 1.0) {
const mult: T = comptime @intFromFloat(@round(ratio));
return .{ .data = self.data *| @as(Vec, @splat(mult)) };
@ -517,7 +465,6 @@ pub fn Tensor(
}
}
// Cross numeric type
var result: DestVec = undefined;
inline for (0..total) |i| {
const float_val: f64 = switch (comptime @typeInfo(T)) {
@ -535,17 +482,6 @@ pub fn Tensor(
return .{ .data = result };
}
//
// Comparisons
//
// Return type: bool when total == 1 (scalar semantics)
// [total]bool when total > 1 (element-wise, flat-indexed)
//
// Whole-tensor equality check eqAll / neAll (always returns bool).
// A shape {1} RHS is broadcast automatically, unifying the old
// eqScalar / gtScalar / family into the plain eq / gt / methods.
//
const CmpResult = if (total == 1) bool else [total]bool;
inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
@ -555,6 +491,9 @@ pub fn Tensor(
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
const RhsType = @TypeOf(rhs_q);
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
const rr: Vec = blk: {
@ -626,26 +565,6 @@ pub fn Tensor(
return !self.eqAll(other);
}
//
// Contraction generalised dot product / matrix multiply / einsum
//
// a.contract(b, axis_a, axis_b)
//
// Sums over dimension `axis_a` of `a` and `axis_b` of `b`.
// Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime).
//
// Result shape = a.shape \ axis_a ++ b.shape \ axis_b
// Result dims = a.dims + b.dims (exponents summed, as in mul)
// Result scales = finer of a, b
//
// Special cases:
// rank-1 × rank-1, axis 0 × 0 dot product (result shape {1})
// rank-2 × rank-2, axis 1 × 0 matrix multiply
// rank-1 × rank-2, axis 0 × 0 vectormatrix product
//
// All index arithmetic is comptime; runtime cost is the multiply-add loop only.
//
pub inline fn contract(
self: Self,
other: anytype,
@ -653,10 +572,10 @@ pub fn Tensor(
comptime axis_b: usize,
) blk: {
const OT = @TypeOf(other);
std.debug.assert(axis_a < rank);
std.debug.assert(axis_b < OT.rank);
std.debug.assert(shape_[axis_a] == OT.shape[axis_b]);
// Contracted-away free axes; empty joint scalar shape {1}
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
const sa = shapeRemoveAxis(shape_, axis_a);
const sb = shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = shapeCat(&sa, &sb);
@ -683,20 +602,16 @@ pub fn Tensor(
rs,
);
// Normalise scales before accumulation
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
// Precompute result strides from rs_raw (for coord decoding)
const rs_raw_strides = comptime shapeStrides(&rs_raw);
var result: ResultType = .{ .data = @splat(0) };
inline for (0..ResultType.total) |res_flat| {
// Decode result flat index into free coords using rs_raw layout.
// When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} correct.
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*;
@ -704,7 +619,6 @@ pub fn Tensor(
var acc: T = 0;
inline for (0..k) |ki| {
// Reinsert the contracted index into free coords full coord arrays
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free);
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free);
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
@ -720,10 +634,6 @@ pub fn Tensor(
return result;
}
//
// Reduction helpers
//
/// Sum of squared elements. Cheaper than length(); use for ordering.
pub inline fn lengthSqr(self: Self) T {
return @reduce(.Add, self.data * self.data);
@ -749,10 +659,6 @@ pub fn Tensor(
return .{ .data = .{@reduce(.Mul, self.data)} };
}
//
// Formatting
//
pub fn formatNumber(
self: Self,
writer: *std.Io.Writer,
@ -809,15 +715,6 @@ pub fn Tensor(
//
// Tests
//
// Naming convention used throughout:
// Tensor(T, d, s, &.{1}) former Scalar
// Tensor(T, d, s, &.{N}) former Vector of length N
// .data[0] former .value()
// .mul(x) former .mulScalar(x) (x may be scalar Tensor or bare number)
// .div(x) former .divScalar(x)
// .eq(x) former .eqScalar(x) (broadcasts when x.total==1)
// .contract(other, 0, 0) former .dot(other) (for rank-1 tensors)
//
// Scalar tests
@ -1234,7 +1131,6 @@ test "Vector Comparisons" {
}
test "Vector vs Scalar broadcast comparison" {
// Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs.
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
@ -1259,7 +1155,6 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
// work = force · pos
const work = force.contract(pos, 0, 0);
try std.testing.expectEqual(50.0, work.data[0]);
try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M));
@ -1268,23 +1163,12 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
}
test "Vector contract — matrix multiply (rank-2 × rank-2)" {
// 2×3 matrix multiplied by 3×2 matrix → 2×2 result
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
// A = [[1, 2, 3],
// [4, 5, 6]]
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
// B = [[7, 8],
// [9, 10],
// [11, 12]]
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
// C = A @ B (contract over axis 1 of A × axis 0 of B)
// C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
// C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
// C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
// C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
const c = a.contract(b, 1, 0);
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
@ -1371,16 +1255,15 @@ test "Vector eq broadcast on dimensionless" {
test "Tensor idx helper and matrix access" {
    const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
    // Diagonal matrix: set [0][0]=1, [1][1]=2, [2][2]=3.
    var m: Mat3x3 = Mat3x3.zero;
    m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
    m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
    m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
    // Verify idx maps to the expected row-major flat positions.
    // (Deduplicated: the same four assertions appeared twice.)
    try std.testing.expectEqual(1.0, m.data[0]); // [0][0]
    try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 — 1*3+1=4)
    try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8)
    try std.testing.expectEqual(0.0, m.data[1]); // [0][1] untouched
}
test "Tensor strides_arr correctness" {