diff --git a/src/Tensor.zig b/src/Tensor.zig index 7cb8a3c..6cc93ab 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -14,6 +14,15 @@ pub fn shapeTotal(comptime shape: []const usize) usize { return t; } +/// Check if two shapes are strictly identical. +pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { + if (a.len != b.len) return false; + for (a, 0..) |v, i| { + if (v != b[i]) return false; + } + return true; +} + /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} @@ -126,10 +135,6 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { } // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers -// -// Any bare comptime_int / comptime_float / runtime T used as an arithmetic -// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}. -// Actual Tensor types are passed through unchanged. // ───────────────────────────────────────────────────────────────────────────── fn isTensor(comptime Rhs: type) bool { @@ -193,32 +198,6 @@ pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { } } -/// ───────────────────────────────────────────────────────────────────────────── -/// Tensor — unified dimensioned ND type. -/// -/// T : element numeric type (f32, f64, i32, i128, …) -/// d_opt : SI dimension exponents -/// s_opt : unit scales -/// shape_ : compile-time shape -/// &.{1} → scalar -/// &.{3} → 3-vector -/// &.{4, 4} → 4×4 matrix -/// &.{3, 3, 3} → 3D field -/// -/// Storage: flat @Vector(total, T) where total = product(shape_). -/// All arithmetic operates on the flat vector directly → SIMD wherever possible. -/// -/// Shape-related comptime constants exposed on every Tensor type: -/// dims : Dimensions — SI exponent struct -/// scales : Scales — unit scale struct -/// shape : []const usize -/// rank : usize = shape.len -/// total : usize = product(shape) -/// strides_arr : [rank]usize — row-major strides -/// -/// Index helper: -/// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) -/// ───────────────────────────────────────────────────────────────────────────── pub fn Tensor( comptime T: type, comptime d_opt: Dimensions.ArgOpts, @@ -226,8 +205,10 @@ pub fn Tensor( comptime shape_: []const usize, ) type { comptime { - std.debug.assert(shape_.len >= 1); - for (shape_) |s| std.debug.assert(s >= 1); + if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); + for (shape_) |s| { + if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); + } } @setEvalBranchQuota(10_000_000); @@ -236,7 +217,6 @@ pub fn Tensor( const Vec = @Vector(_total, T); return struct { - /// Flat SIMD storage. All arithmetic operates here directly. data: Vec, const Self = @This(); @@ -250,27 +230,19 @@ pub fn Tensor( pub const strides_arr: [shape_.len]usize = _strides; pub const ISTENSOR = true; - // ─────────────────────────────────────────────────────────────── - // Index helper - // ─────────────────────────────────────────────────────────────── - /// Convert N-D coords (row-major) to flat index — fully comptime. 
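A quick comptime sketch (in the style of the tests at the end of this file) of what the new shapeEql guard accepts: rank and every extent must match, which is the property the element-wise ops below now check instead of raw element counts.

test "shapeEql is strict about rank and extents" {
    try std.testing.expect(comptime shapeEql(&.{ 2, 3 }, &.{ 2, 3 }));
    // Same element count, different rank → not equal.
    try std.testing.expect(comptime !shapeEql(&.{3}, &.{ 3, 1 }));
    // Same rank, transposed extents → not equal.
    try std.testing.expect(comptime !shapeEql(&.{ 2, 3 }, &.{ 3, 2 }));
}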
/// Usage: Tensor.idx(.{row, col}) pub inline fn idx(comptime coords: [rank]usize) usize { comptime { var flat: usize = 0; for (0..rank) |i| { - std.debug.assert(coords[i] < shape[i]); + if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds"); flat += coords[i] * strides_arr[i]; } return flat; } } - // ─────────────────────────────────────────────────────────────── - // Constructors - // ─────────────────────────────────────────────────────────────── - /// Broadcast a single value across all elements. pub inline fn splat(v: T) Self { return .{ .data = @splat(v) }; @@ -279,19 +251,11 @@ pub fn Tensor( pub const zero: Self = splat(0); pub const one: Self = splat(1); - // ─────────────────────────────────────────────────────────────── - // GPU readiness - // ─────────────────────────────────────────────────────────────── - /// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping. pub inline fn asSlice(self: *Self) []T { return @as([*]T, @ptrCast(&self.data))[0..total]; } - // ─────────────────────────────────────────────────────────────── - // Internal: RHS normalisation - // ─────────────────────────────────────────────────────────────── - inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } @@ -299,10 +263,6 @@ pub fn Tensor( return toRhsTensor(T, r); } - // ─────────────────────────────────────────────────────────────── - // Internal: scalar broadcast (shape {1} → full Vec) - // ─────────────────────────────────────────────────────────────── - inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { return if (comptime RhsType.total == 1 and total > 1) @splat(r.data[0]) @@ -310,12 +270,8 @@ pub fn Tensor( r.data; } - // ─────────────────────────────────────────────────────────────── - // Arithmetic - // ─────────────────────────────────────────────────────────────── - /// Element-wise add. Dimensions must match; scales resolve to finer. - /// RHS must have the same element count as self, or total == 1 (broadcast). + /// RHS must have the same shape as self, or total == 1 (broadcast). pub inline fn add(self: Self, r: anytype) Tensor( T, dims.argsOpt(), @@ -326,8 +282,8 @@ pub fn Tensor( const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -339,7 +295,8 @@ pub fn Tensor( return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; } - /// Element-wise subtract. Dimensions must match; scales resolve to finer. + /// Element-wise sub. Dimensions must match; scales resolve to finer. + /// RHS must have the same shape as self, or total == 1 (broadcast). 
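A usage sketch of the broadcast rule above, not part of the diff itself: a shape-{1} RHS is splatted across the full vector and mixed scales resolve to the finer one. It assumes the `.k` scale carries the usual 10³ factor, as in the existing broadcast-comparison test.

test "add broadcasts a shape-{1} RHS and resolves to the finer scale" {
    const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
    const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const v = Meter3{ .data = .{ 1.0, 2.0, 3.0 } };
    const offset = KiloMeter1{ .data = .{1.0} };

    // 1 km is broadcast as 1000 m across all three lanes.
    const sum = v.add(offset);
    try std.testing.expectEqual(1001.0, sum.data[0]);
    try std.testing.expectEqual(1003.0, sum.data[2]);
}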
pub inline fn sub(self: Self, r: anytype) Tensor( T, dims.argsOpt(), @@ -350,8 +307,8 @@ pub fn Tensor( const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in sub."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -373,8 +330,8 @@ pub fn Tensor( ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in mul."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); @@ -394,8 +351,8 @@ pub fn Tensor( ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in div."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); @@ -411,10 +368,6 @@ pub fn Tensor( } } - // ─────────────────────────────────────────────────────────────── - // Unary - // ─────────────────────────────────────────────────────────────── - /// Absolute value of every element. pub inline fn abs(self: Self) Self { return .{ .data = @bitCast(@abs(self.data)) }; @@ -469,13 +422,9 @@ pub fn Tensor( return .{ .data = -self.data }; } - // ─────────────────────────────────────────────────────────────── - // Conversion - // ─────────────────────────────────────────────────────────────── - /// Convert to a compatible Tensor type. /// • Dimension mismatch → compile error. - /// • Dest.total must equal self.total, or Dest.total == 1 (scalar pattern). + /// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern). /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. 
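A short sketch of the comptime scale conversion described above, again assuming the usual 10³ factor for `.k`: the whole ratio folds into a single SIMD multiply at runtime.

test "to() rescales kilometres to metres at comptime" {
    const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1});
    const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const d = KiloMeter1{ .data = .{1.5} };
    const m = d.to(Meter1);
    try std.testing.expectEqual(1500.0, m.data[0]);
}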
pub inline fn to( self: Self, @@ -485,20 +434,19 @@ pub fn Tensor( if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - if (comptime Self == ActualDest) return self; + if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) + @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); - comptime std.debug.assert(Dest.total == total or Dest.total == 1); + if (comptime Self == ActualDest) return self; const DestT = ActualDest.ValueType; const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); const DestVec = @Vector(total, DestT); - // ── Same numeric type ────────────────────────────────────── if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; - // Integer — branch prevents division-by-zero if (comptime ratio >= 1.0) { const mult: T = comptime @intFromFloat(@round(ratio)); return .{ .data = self.data *| @as(Vec, @splat(mult)) }; @@ -517,7 +465,6 @@ pub fn Tensor( } } - // ── Cross numeric type ───────────────────────────────────── var result: DestVec = undefined; inline for (0..total) |i| { const float_val: f64 = switch (comptime @typeInfo(T)) { @@ -535,17 +482,6 @@ pub fn Tensor( return .{ .data = result }; } - // ─────────────────────────────────────────────────────────────── - // Comparisons - // - // Return type: bool when total == 1 (scalar semantics) - // [total]bool when total > 1 (element-wise, flat-indexed) - // - // Whole-tensor equality check → eqAll / neAll (always returns bool). - // A shape {1} RHS is broadcast automatically, unifying the old - // eqScalar / gtScalar / … family into the plain eq / gt / … methods. - // ─────────────────────────────────────────────────────────────── - const CmpResult = if (total == 1) bool else [total]bool; inline fn cmpResult(v: @Vector(total, bool)) CmpResult { @@ -555,6 +491,9 @@ pub fn Tensor( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { const RhsType = @TypeOf(rhs_q); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { @@ -626,26 +565,6 @@ pub fn Tensor( return !self.eqAll(other); } - // ─────────────────────────────────────────────────────────────── - // Contraction — generalised dot product / matrix multiply / einsum - // - // a.contract(b, axis_a, axis_b) - // - // Sums over dimension `axis_a` of `a` and `axis_b` of `b`. - // Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime). - // - // Result shape = a.shape \ axis_a ++ b.shape \ axis_b - // Result dims = a.dims + b.dims (exponents summed, as in mul) - // Result scales = finer of a, b - // - // Special cases: - // rank-1 × rank-1, axis 0 × 0 → dot product (result shape {1}) - // rank-2 × rank-2, axis 1 × 0 → matrix multiply - // rank-1 × rank-2, axis 0 × 0 → vector–matrix product - // - // All index arithmetic is comptime; runtime cost is the multiply-add loop only. 
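A sketch of the rank-1 × rank-2 special case of the contraction below (a vector-matrix product), which the tests further down do not exercise. It assumes the result keeps the free axes of the left operand first, then those of the right, as the dot-product and matrix-multiply tests imply.

test "contract vector with matrix (rank-1 x rank-2)" {
    const V2 = Tensor(f32, .{}, .{}, &.{2});
    const M2x3 = Tensor(f32, .{}, .{}, &.{ 2, 3 });

    const v = V2{ .data = .{ 1.0, 2.0 } };
    const m = M2x3{ .data = .{ 1, 2, 3, 4, 5, 6 } };

    // r[j] = v[0]*m[0][j] + v[1]*m[1][j] → { 9, 12, 15 }
    const r = v.contract(m, 0, 0);
    try std.testing.expectEqual(9.0, r.data[0]);
    try std.testing.expectEqual(12.0, r.data[1]);
    try std.testing.expectEqual(15.0, r.data[2]);
}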
- // ─────────────────────────────────────────────────────────────── - pub inline fn contract( self: Self, other: anytype, @@ -653,10 +572,10 @@ pub fn Tensor( comptime axis_b: usize, ) blk: { const OT = @TypeOf(other); - std.debug.assert(axis_a < rank); - std.debug.assert(axis_b < OT.rank); - std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); - // Contracted-away free axes; empty joint → scalar shape {1} + if (axis_a >= rank) @compileError("contract: axis_a out of bounds"); + if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds"); + if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes"); + const sa = shapeRemoveAxis(shape_, axis_a); const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); @@ -683,20 +602,16 @@ pub fn Tensor( rs, ); - // Normalise scales before accumulation const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; - // Precompute result strides from rs_raw (for coord decoding) const rs_raw_strides = comptime shapeStrides(&rs_raw); var result: ResultType = .{ .data = @splat(0) }; inline for (0..ResultType.total) |res_flat| { - // Decode result flat index into free coords using rs_raw layout. - // When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} — correct. const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; @@ -704,7 +619,6 @@ pub fn Tensor( var acc: T = 0; inline for (0..k) |ki| { - // Reinsert the contracted index into free coords → full coord arrays const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); @@ -720,10 +634,6 @@ pub fn Tensor( return result; } - // ─────────────────────────────────────────────────────────────── - // Reduction helpers - // ─────────────────────────────────────────────────────────────── - /// Sum of squared elements. Cheaper than length(); use for ordering. 
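A one-line sketch of the reduction below: lengthSqr is a plain @reduce over the element-wise square and returns the bare element type.

test "lengthSqr sums squared elements" {
    const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
    const v = Meter3{ .data = .{ 3.0, 4.0, 0.0 } };
    try std.testing.expectEqual(25.0, v.lengthSqr());
}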
pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); @@ -749,10 +659,6 @@ pub fn Tensor( return .{ .data = .{@reduce(.Mul, self.data)} }; } - // ─────────────────────────────────────────────────────────────── - // Formatting - // ─────────────────────────────────────────────────────────────── - pub fn formatNumber( self: Self, writer: *std.Io.Writer, @@ -809,15 +715,6 @@ pub fn Tensor( // ═════════════════════════════════════════════════════════════════════════════ // Tests // ───────────────────────────────────────────────────────────────────────────── -// Naming convention used throughout: -// Tensor(T, d, s, &.{1}) → former Scalar -// Tensor(T, d, s, &.{N}) → former Vector of length N -// .data[0] → former .value() -// .mul(x) → former .mulScalar(x) (x may be scalar Tensor or bare number) -// .div(x) → former .divScalar(x) -// .eq(x) → former .eqScalar(x) (broadcasts when x.total==1) -// .contract(other, 0, 0) → former .dot(other) (for rank-1 tensors) -// ═════════════════════════════════════════════════════════════════════════════ // ─── Scalar tests ───────────────────────────────────────────────────────── @@ -1234,7 +1131,6 @@ test "Vector Comparisons" { } test "Vector vs Scalar broadcast comparison" { - // Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs. const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); @@ -1259,7 +1155,6 @@ test "Vector contract — dot product (rank-1 × rank-1)" { const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; - // work = force · pos const work = force.contract(pos, 0, 0); try std.testing.expectEqual(50.0, work.data[0]); try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); @@ -1268,23 +1163,12 @@ test "Vector contract — dot product (rank-1 × rank-1)" { } test "Vector contract — matrix multiply (rank-2 × rank-2)" { - // 2×3 matrix multiplied by 3×2 matrix → 2×2 result const A = Tensor(f32, .{}, .{}, &.{ 2, 3 }); const B = Tensor(f32, .{}, .{}, &.{ 3, 2 }); - // A = [[1, 2, 3], - // [4, 5, 6]] const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; - // B = [[7, 8], - // [9, 10], - // [11, 12]] const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; - // C = A @ B (contract over axis 1 of A × axis 0 of B) - // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58 - // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64 - // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139 - // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154 const c = a.contract(b, 1, 0); try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); @@ -1371,16 +1255,15 @@ test "Vector eq broadcast on dimensionless" { test "Tensor idx helper and matrix access" { const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 }); - // Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3 var m: Mat3x3 = Mat3x3.zero; m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0; m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0; m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0; - try std.testing.expectEqual(1.0, m.data[0]); // [0][0] - try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4) - try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8) - try std.testing.expectEqual(0.0, m.data[1]); // [0][1] + try std.testing.expectEqual(1.0, m.data[0]); + try std.testing.expectEqual(2.0, m.data[4]); + try std.testing.expectEqual(3.0, m.data[8]); + try 
std.testing.expectEqual(0.0, m.data[1]); } test "Tensor strides_arr correctness" {