From 934a40fe1a643534488b43f35d9139de476f6e1b Mon Sep 17 00:00:00 2001 From: Adrien Bouvais Date: Sun, 26 Apr 2026 22:16:25 +0200 Subject: [PATCH] Basic untested tensor --- src/Quantity.zig | 1405 ++++++++++++++++++++++++---------------------- 1 file changed, 747 insertions(+), 658 deletions(-) diff --git a/src/Quantity.zig b/src/Quantity.zig index 4e30ec9..f0edea6 100644 --- a/src/Quantity.zig +++ b/src/Quantity.zig @@ -5,240 +5,385 @@ const UnitScale = Scales.UnitScale; const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; -// --------------------------------------------------------------------------- -// Quantity — single unified dimensioned type. -// -// T : element numeric type (f32, f64, i32, i128, …) -// N : lane count (1 = Scalar, >1 = Vector) -// d : SI dimension exponents -// s : unit scales -// -// All arithmetic is performed directly on the underlying @Vector(N, T), so -// the compiler can emit SIMD instructions wherever the target supports them. -// -// Thin aliases (same type identity, no wrapper overhead): -// Scalar(T, d, s) ≡ Quantity(T, 1, d, s) -// Vector(N, Q) ≡ Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()) -// --------------------------------------------------------------------------- -// -// To investigate: -// @reduce(comptime op: std.builtin.ReduceOp, value: anytype) E -// @select(comptime T: type, pred: @Vector(len, bool), a: @Vector(len, T), b: @Vector(len, T)) @Vector(len, T) -// @shuffle(comptime E: type, a: @Vector(a_len, E), b: @Vector(b_len, E), comptime mask: @Vector(mask_len, i32)) @Vector(mask_len, E) +// ───────────────────────────────────────────────────────────────────────────── +// Comptime shape utilities +// ───────────────────────────────────────────────────────────────────────────── -pub fn Quantity( +pub fn shapeTotal(comptime shape: []const usize) usize { + var t: usize = 1; + for (shape) |s| t *= s; + return t; +} + +/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). +/// e.g. shape {3, 4} → strides {4, 1} +/// shape {2, 3, 4} → strides {12, 4, 1} +pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { + var st: [shape.len]usize = undefined; + if (shape.len == 0) return st; + st[shape.len - 1] = 1; + if (shape.len > 1) { + var i: usize = shape.len - 1; + while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; + } + return st; +} + +/// Return a copy of `shape` with the element at `axis` removed. +pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { + var out: [shape.len - 1]usize = undefined; + var j: usize = 0; + for (shape, 0..) |v, i| { + if (i != axis) { out[j] = v; j += 1; } + } + return out; +} + +/// Concatenate two compile-time slices. +pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { + var out: [a.len + b.len]usize = undefined; + for (a, 0..) |v, i| out[i] = v; + for (b, 0..) |v, i| out[a.len + i] = v; + return out; +} + +/// Decode a flat row-major index into N-D coordinates. +/// Called only in comptime contexts (all arguments are comptime). +pub fn decodeFlatCoords( + comptime flat: usize, + comptime n: usize, + comptime strd: [n]usize, +) [n]usize { + var coords: [n]usize = undefined; + var tmp = flat; + for (0..n) |i| { + coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; + tmp = if (strd[i] == 0) 0 else tmp % strd[i]; + } + return coords; +} + +/// Encode N-D coordinates into a flat row-major index. +/// Called only in comptime contexts. +pub fn encodeFlatCoords( + comptime coords: []const usize, + comptime n: usize, + comptime strd: [n]usize, +) usize { + var flat: usize = 0; + for (0..n) |i| flat += coords[i] * strd[i]; + return flat; +} + +/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`. +/// `free` holds the remaining (non-contracted) coordinates in order. +pub fn insertAxis( + comptime n: usize, + comptime axis: usize, + comptime val: usize, + comptime free: []const usize, +) [n]usize { + var out: [n]usize = undefined; + var fi: usize = 0; + for (0..n) |i| { + if (i == axis) { + out[i] = val; + } else { + out[i] = free[fi]; + fi += 1; + } + } + return out; +} + +// ───────────────────────────────────────────────────────────────────────────── +// File-scope RHS normalisation helpers +// +// Any bare comptime_int / comptime_float / runtime T used as an arithmetic +// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}. +// Actual Tensor types are passed through unchanged. +// ───────────────────────────────────────────────────────────────────────────── + +fn RhsTensorType(comptime T: type, comptime Rhs: type) type { + if (@hasDecl(Rhs, "ISTENSOR")) return Rhs; + return Tensor(T, .{}, .{}, &.{1}); +} + +fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { + const Rhs = @TypeOf(r); + if (comptime @hasDecl(Rhs, "ISTENSOR")) return r; + const scalar: T = switch (comptime @typeInfo(Rhs)) { + .comptime_int => switch (comptime @typeInfo(T)) { + .float => @as(T, @floatFromInt(r)), + else => @as(T, r), + }, + .comptime_float => switch (comptime @typeInfo(T)) { + .int => @as(T, @intFromFloat(r)), + else => @as(T, r), + }, + .int => switch (comptime @typeInfo(T)) { + .float => @floatFromInt(r), + else => @intCast(r), + }, + .float => switch (comptime @typeInfo(T)) { + .int => @intFromFloat(r), + else => @floatCast(r), + }, + else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)), + }; + return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tensor — unified dimensioned ND type. +// +// T : element numeric type (f32, f64, i32, i128, …) +// d_opt : SI dimension exponents +// s_opt : unit scales +// shape_ : compile-time shape +// &.{1} → scalar +// &.{3} → 3-vector +// &.{4, 4} → 4×4 matrix +// &.{3, 3, 3} → 3D field +// +// Storage: flat @Vector(total, T) where total = product(shape_). +// All arithmetic operates on the flat vector directly → SIMD wherever possible. +// +// Shape-related comptime constants exposed on every Tensor type: +// dims : Dimensions — SI exponent struct +// scales : Scales — unit scale struct +// shape : []const usize +// rank : usize = shape.len +// total : usize = product(shape) +// strides_arr : [rank]usize — row-major strides +// +// Index helper: +// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) +// +// GPU readiness: +// tensor.asSlice() → []T (zero-copy pointer to the flat @Vector storage) +// +// Contraction (replaces dot / cross / matmul): +// a.contract(b, axis_a, axis_b) +// For rank-1 × rank-1 this is the dot product. +// For rank-2 × rank-2 with axis_a=1, axis_b=0 this is matrix multiply. +// +// Removed from Quantity: +// Scalar / Vector aliases, Vec3 / ScalarType, .value(), .vec(), .vec3(), +// dot(), cross(), mulScalar(), divScalar(), eqScalar() and friends. +// Use Tensor(..., &.{1}), .data[0], mul(), div(), eq() respectively. +// ───────────────────────────────────────────────────────────────────────────── + +pub fn Tensor( comptime T: type, - comptime N: usize, comptime d_opt: Dimensions.ArgOpts, comptime s_opt: Scales.ArgOpts, + comptime shape_: []const usize, ) type { - comptime std.debug.assert(N >= 1); + comptime { + std.debug.assert(shape_.len >= 1); + for (shape_) |s| std.debug.assert(s >= 1); + } @setEvalBranchQuota(10_000_000); - // Local shorthand for the SIMD vector type used in storage. - const Vec = @Vector(N, T); + const _total: usize = comptime shapeTotal(shape_); + const _strides = comptime shapeStrides(shape_); + const Vec = @Vector(_total, T); return struct { - /// SIMD-friendly storage. Arithmetic operates here directly. + /// Flat SIMD storage. All arithmetic operates here directly. data: Vec, const Self = @This(); - pub const ValueType: type = T; - pub const Len: usize = N; - pub const dims: Dimensions = Dimensions.init(d_opt); - pub const scales: Scales = Scales.init(s_opt); - pub const ISQUANTITY = true; + pub const ValueType : type = T; + pub const dims : Dimensions = Dimensions.init(d_opt); + pub const scales : Scales = Scales.init(s_opt); + pub const shape : []const usize = shape_; + pub const rank : usize = shape_.len; + pub const total : usize = _total; + pub const strides_arr: [shape_.len]usize = _strides; + pub const ISTENSOR = true; - /// Scalar variant of this quantity (lane=1). Returned by dot(), product(), etc. - pub const ScalarType: type = Quantity(T, 1, d_opt, s_opt); - /// Convenience: a 3-lane vector of the same dimension/scale. - pub const Vec3: type = Quantity(T, 3, d_opt, s_opt); + // ─────────────────────────────────────────────────────────────── + // Index helper + // ─────────────────────────────────────────────────────────────── - // --------------------------------------------------------------- + /// Convert N-D coords (row-major) to flat index — fully comptime. + /// Usage: Tensor.idx(.{row, col}) + pub fn idx(comptime coords: [rank]usize) usize { + comptime { + var flat: usize = 0; + for (0..rank) |i| { + std.debug.assert(coords[i] < shape[i]); + flat += coords[i] * strides_arr[i]; + } + return flat; + } + } + + // ─────────────────────────────────────────────────────────────── // Constructors - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Broadcast a single value across all N lanes. + /// Broadcast a single value across all elements. pub inline fn splat(v: T) Self { return .{ .data = @splat(v) }; } pub const zero: Self = splat(0); - pub const one: Self = splat(1); + pub const one: Self = splat(1); - // --------------------------------------------------------------- - // Scalar-only helpers (N = 1) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // GPU readiness + // ─────────────────────────────────────────────────────────────── - /// Return the single scalar value. Compile error when N ≠ 1. - pub inline fn value(self: Self) T { - comptime if (N != 1) - @compileError(".value() is only available on Scalar (N=1)."); - return self.data[0]; + /// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping. + pub inline fn asSlice(self: *Self) []T { + return @as([*]T, @ptrCast(&self.data))[0..total]; } - /// Expand this scalar into a len-lane vector by splatting. - pub inline fn vec(self: Self, comptime len: usize) Quantity(T, len, d_opt, s_opt) { - comptime if (N != 1) - @compileError(".vec() is only available on Scalar (N=1)."); - return .{ .data = @splat(self.data[0]) }; - } - - pub inline fn vec3(self: Self) Vec3 { - return self.vec(3); - } - - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Internal: RHS normalisation - // - // • For N=1 (Scalar context): bare numbers → Quantity(T, 1, dimless, none) - // • For N>1 (Vector context): bare numbers → Quantity(T, N, dimless, none) - // Quantity(T,1) → broadcast (handled in each op) - // - // A bare number used as rhs is ALWAYS treated as dimensionless. - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - inline fn RhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, N, Rhs); - } - inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, N, r); - } + inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } + inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { return toRhsTensor(T, r); } - /// Scalar rhs (N=1) — used by mulScalar / divScalar / eqScalar etc. - inline fn ScalarRhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, 1, Rhs); - } - inline fn scalarRhs(r: anytype) ScalarRhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, 1, r); - } + // ─────────────────────────────────────────────────────────────── + // Internal: scalar broadcast (shape {1} → full Vec) + // ─────────────────────────────────────────────────────────────── - // --------------------------------------------------------------- - // Internal: broadcast helper - // - // When an N=1 rhs is used in an N>1 operation, splat it. - // --------------------------------------------------------------- inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { - if (comptime RhsType.Len == 1 and N > 1) - return @splat(r.data[0]) + return if (comptime RhsType.total == 1 and total > 1) + @splat(r.data[0]) else - return r.data; + r.data; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Arithmetic - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── /// Element-wise add. Dimensions must match; scales resolve to finer. - /// For N=1: rhs may be a bare number (treated as dimensionless). - /// For N>1: rhs must be a same-length Quantity. - pub inline fn add(self: Self, r: anytype) Quantity( + /// RHS must have the same element count as self, or total == 1 (broadcast). + pub inline fn add(self: Self, r: anytype) Tensor( T, - N, dims.argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector add requires same-length Quantity."); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); + }; return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr }; } /// Element-wise subtract. Dimensions must match; scales resolve to finer. - pub inline fn sub(self: Self, r: anytype) Quantity( + pub inline fn sub(self: Self, r: anytype) Tensor( T, - N, dims.argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector sub requires same-length Quantity."); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in sub."); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); + }; return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr }; } - /// Element-wise multiply. Dimension exponents are summed. - /// An N=1 rhs on an N>1 self is automatically broadcast (scalar × vector). - pub inline fn mul(self: Self, r: anytype) Quantity( + /// Element-wise multiply. Dimension exponents summed. + /// Shape {1} RHS is automatically broadcast across all elements. + pub inline fn mul(self: Self, r: anytype) Tensor( T, - N, dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in mul."); + + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr }; } - /// Element-wise divide. Dimension exponents are subtracted. - /// An N=1 rhs on an N>1 self is automatically broadcast. - pub inline fn div(self: Self, r: anytype) Quantity( + /// Element-wise divide. Dimension exponents subtracted. + /// Shape {1} RHS is automatically broadcast across all elements. + pub inline fn div(self: Self, r: anytype) Tensor( T, - N, dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in div."); + + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); if (comptime hlp.isInt(T)) { var result: Vec = undefined; - inline for (0..N) |i| result[i] = @divTrunc(l[i], rr[i]); + inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]); return .{ .data = result }; } else { return .{ .data = l / rr }; } } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Unary - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Absolute value of every lane. Uses native `@abs` (SIMD for floats & ints). + /// Absolute value of every element. pub inline fn abs(self: Self) Self { return .{ .data = @bitCast(@abs(self.data)) }; } - /// Raise every lane to a comptime integer exponent. - /// Repeated SIMD multiply — good for small exponents. - pub inline fn pow(self: Self, comptime exp: comptime_int) Quantity( + /// Raise every element to a comptime integer exponent. + pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor( T, - N, dims.scale(exp).argsOpt(), scales.argsOpt(), + shape_, ) { if (comptime hlp.isInt(T)) { - // No SIMD pow for integers — element-wise std.math.powi. var result: Vec = undefined; - inline for (0..N) |i| + inline for (0..total) |i| result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); return .{ .data = result }; } else { - // Float: unrolled SIMD multiplications. const abs_exp = comptime @abs(exp); var result: Vec = @splat(1); comptime var i = 0; @@ -248,130 +393,127 @@ pub fn Quantity( } } - /// Square root of every lane. All dimension exponents must be even. - pub inline fn sqrt(self: Self) Quantity( + /// Square root of every element. All dimension exponents must be even. + pub inline fn sqrt(self: Self) Tensor( T, - N, dims.div(2).argsOpt(), scales.argsOpt(), + shape_, ) { if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { return .{ .data = @sqrt(self.data) }; } else { - // Integer sqrt is not SIMD-able — element-wise. var result: Vec = undefined; const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - inline for (0..N) |i| { + inline for (0..total) |i| { const v = self.data[i]; - if (v < 0) - result[i] = 0 - else - result[i] = @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); + result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } return .{ .data = result }; } } - /// Negate every lane. + /// Negate every element. pub inline fn negate(self: Self) Self { return .{ .data = -self.data }; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Conversion - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Convert to a compatible quantity type. Dimension mismatch is a compile error. - /// Dest can have the same Len as this quantity, or Len == 1 (in which case it - /// will be automatically recast to this quantity's Len). - /// The scale ratio is computed entirely at comptime; the only runtime cost is - /// a SIMD multiply-by-splat (or element-wise cast for cross-numeric-type conversions). + /// Convert to a compatible Tensor type. + /// • Dimension mismatch → compile error. + /// • Dest.total must equal self.total, or Dest.total == 1 (scalar pattern). + /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. pub inline fn to( self: Self, comptime Dest: type, - ) Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()) { - const ActualDest = Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()); + ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { + const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - if (comptime Self == ActualDest) return self; - // Allow Dest to be exactly matching Len or a Scalar (Len == 1) - comptime std.debug.assert(Dest.Len == N or Dest.Len == 1); + comptime std.debug.assert(Dest.total == total or Dest.total == 1); - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestVec = @Vector(N, DestT); + const DestT = ActualDest.ValueType; + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestVec = @Vector(total, DestT); - // ── Same numeric type path ── + // ── Same numeric type ────────────────────────────────────── if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; - // Integer logic: Branching prevents division by zero errors + // Integer — branch prevents division-by-zero if (comptime ratio >= 1.0) { - // Upscaling (e.g., km -> m, ratio = 1000) const mult: T = comptime @intFromFloat(@round(ratio)); return .{ .data = self.data *| @as(Vec, @splat(mult)) }; } else { - // Downscaling (e.g., m -> km, ratio = 0.001) const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); + const half: T = comptime @divTrunc(div_val, 2); var result: DestVec = undefined; - const half: T = comptime @divTrunc(div_val, 2); - - inline for (0..N) |i| { + inline for (0..total) |i| { const val = self.data[i]; - // Rounding division for integers - result[i] = if (val >= 0) @divTrunc(val + half, div_val) else @divTrunc(val - half, div_val); + result[i] = if (val >= 0) + @divTrunc(val + half, div_val) + else + @divTrunc(val - half, div_val); } return .{ .data = result }; } } - // ── Cross-numeric-type (unchanged) ── + // ── Cross numeric type ───────────────────────────────────── var result: DestVec = undefined; - inline for (0..N) |i| { + inline for (0..total) |i| { const float_val: f64 = switch (comptime @typeInfo(T)) { .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, + .int => @floatFromInt(self.data[i]), + else => unreachable, }; const scaled = float_val * ratio; result[i] = switch (comptime @typeInfo(DestT)) { .float => @floatCast(scaled), - .int => @intFromFloat(@round(scaled)), - else => unreachable, + .int => @intFromFloat(@round(scaled)), + else => unreachable, }; } return .{ .data = result }; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Comparisons // - // Return type: bool when N = 1 (Scalar semantics) - // [N]bool when N > 1 (Vector semantics, element-wise) + // Return type: bool when total == 1 (scalar semantics) + // [total]bool when total > 1 (element-wise, flat-indexed) // - // Whole-vector "all equal/any differ" → use eqAll / neAll. - // Broadcast scalar comparison → use eqScalar / gtScalar / … - // --------------------------------------------------------------- + // Whole-tensor equality check → eqAll / neAll (always returns bool). + // A shape {1} RHS is broadcast automatically, unifying the old + // eqScalar / gtScalar / … family into the plain eq / gt / … methods. + // ─────────────────────────────────────────────────────────────── - const CmpResult = if (N == 1) bool else [N]bool; + const CmpResult = if (total == 1) bool else [total]bool; - inline fn cmpResult(v: @Vector(N, bool)) CmpResult { - return if (comptime N == 1) v[0] else @as([N]bool, v); + inline fn cmpResult(v: @Vector(total, bool)) CmpResult { + return if (comptime total == 1) v[0] else @as([total]bool, v); } + /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { - const RhsType = @TypeOf(rhs_q); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - return .{ - .l = if (comptime Self == TargetType) self.data else self.to(TargetType).data, - .r = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data, + const RhsType = @TypeOf(rhs_q); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); }; + return .{ .l = l, .r = rr }; } pub inline fn eq(self: Self, r: anytype) CmpResult { @@ -422,11 +564,7 @@ pub fn Quantity( return cmpResult(p.l <= p.r); } - // --------------------------------------------------------------- - // Vector whole-quantity comparisons (N > 1 intended, but work for N=1 too) - // --------------------------------------------------------------- - - /// True iff every lane is equal after scale resolution. + /// True iff every element is equal after scale resolution. pub inline fn eqAll(self: Self, other: anytype) bool { if (comptime !dims.eql(@TypeOf(other).dims)) @compileError("Dimension mismatch in eqAll."); @@ -434,121 +572,115 @@ pub fn Quantity( return @reduce(.And, p.l == p.r); } + /// True iff any element differs after scale resolution. pub inline fn neAll(self: Self, other: anytype) bool { return !self.eqAll(other); } - // --------------------------------------------------------------- - // Vector broadcast-scalar comparisons (always returns [N]bool) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // Contraction — generalised dot product / matrix multiply / einsum + // + // a.contract(b, axis_a, axis_b) + // + // Sums over dimension `axis_a` of `a` and `axis_b` of `b`. + // Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime). + // + // Result shape = a.shape \ axis_a ++ b.shape \ axis_b + // Result dims = a.dims + b.dims (exponents summed, as in mul) + // Result scales = finer of a, b + // + // Special cases: + // rank-1 × rank-1, axis 0 × 0 → dot product (result shape {1}) + // rank-2 × rank-2, axis 1 × 0 → matrix multiply + // rank-1 × rank-2, axis 0 × 0 → vector–matrix product + // + // All index arithmetic is comptime; runtime cost is the multiply-add loop only. + // ─────────────────────────────────────────────────────────────── - inline fn broadcastScalarForCmp(self: Self, scalar: anytype) struct { l: Vec, r: Vec } { - const s = scalarRhs(scalar); - const SN = @TypeOf(s); - const TargetScalar = Quantity(T, 1, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const TargetSelf = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const s_val: T = if (comptime SN == TargetScalar) s.data[0] else s.to(TargetScalar).data[0]; - const l: Vec = if (comptime Self == TargetSelf) self.data else self.to(TargetSelf).data; - return .{ .l = l, .r = @splat(s_val) }; + pub inline fn contract( + self: Self, + other: anytype, + comptime axis_a: usize, + comptime axis_b: usize, + ) blk: { + const OT = @TypeOf(other); + comptime std.debug.assert(axis_a < rank); + comptime std.debug.assert(axis_b < OT.rank); + comptime std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); + // Contracted-away free axes; empty joint → scalar shape {1} + const sa = shapeRemoveAxis(shape_, axis_a); + const sb = shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = shapeCat(&sa, &sb); + const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; + break :blk Tensor( + T, + dims.add(OT.dims).argsOpt(), + hlp.finerScales(Self, OT).argsOpt(), + rs, + ); + } { + const OT = @TypeOf(other); + const k: usize = comptime shape_[axis_a]; // contraction dimension + + const sa = comptime shapeRemoveAxis(shape_, axis_a); + const sb = comptime shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = comptime shapeCat(&sa, &sb); + const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; + + const ResultType = Tensor( + T, + dims.add(OT.dims).argsOpt(), + hlp.finerScales(Self, OT).argsOpt(), + rs, + ); + + // Normalise scales before accumulation + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), shape_); + const OtherNorm = Tensor(T, OT.dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), OT.shape); + const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; + + // Precompute result strides from rs_raw (for coord decoding) + const rs_raw_strides = comptime shapeStrides(&rs_raw); + + var result: ResultType = .{ .data = @splat(0) }; + + inline for (0..ResultType.total) |res_flat| { + // Decode result flat index into free coords using rs_raw layout. + // When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} — correct. + const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); + + const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; + const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; + + var acc: T = 0; + inline for (0..k) |ki| { + // Reinsert the contracted index into free coords → full coord arrays + const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); + const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); + const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + + if (comptime hlp.isInt(T)) + acc +|= a_data[a_flat] *| b_data[b_flat] + else + acc += a_data[a_flat] * b_data[b_flat]; + } + result.data[res_flat] = acc; + } + return result; } - pub inline fn eqScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l == p.r); - } + // ─────────────────────────────────────────────────────────────── + // Reduction helpers + // ─────────────────────────────────────────────────────────────── - pub inline fn neScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l != p.r); - } - - pub inline fn gtScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l > p.r); - } - - pub inline fn gteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l >= p.r); - } - - pub inline fn ltScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l < p.r); - } - - pub inline fn lteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l <= p.r); - } - - // --------------------------------------------------------------- - // Vector broadcast multiply / divide - // (These are explicit aliases for mul/div with an N=1 rhs; kept for - // clarity and backward-compat with the old Vector API.) - // --------------------------------------------------------------- - - pub inline fn mulScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.add(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.mul(scalar); - } - - pub inline fn divScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.sub(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.div(scalar); - } - - // --------------------------------------------------------------- - // Vector geometric operations - // --------------------------------------------------------------- - - /// Dot product — sum of element-wise products; returns a Scalar. - pub inline fn dot(self: Self, other: anytype) Quantity( - T, - 1, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - const Tr = @TypeOf(other); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const OtherNorm = Quantity(T, N, Tr.dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const r2: Vec = if (comptime Tr == OtherNorm) other.data else other.to(OtherNorm).data; - return .{ .data = .{@reduce(.Add, l * r2)} }; - } - - /// 3D cross product. Requires N = 3. - pub inline fn cross(self: Self, other: anytype) Quantity( - T, - 3, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - comptime if (N != 3) @compileError("cross() requires len=3."); - const a = self.data; - const b = other.data; - return .{ .data = .{ - a[1] * b[2] - a[2] * b[1], - a[2] * b[0] - a[0] * b[2], - a[0] * b[1] - a[1] * b[0], - } }; - } - - /// Sum of squared components. Cheaper than length(); use for comparisons. + /// Sum of squared elements. Cheaper than length(); use for ordering. pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); } - /// Euclidean length. Float types use SIMD @reduce → @sqrt. - /// Integer types use integer sqrt (truncated). + /// Euclidean length (L2 norm). pub inline fn length(self: Self) T { const sq = self.lengthSqr(); if (comptime @typeInfo(T) == .int) { @@ -558,43 +690,46 @@ pub fn Quantity( return @sqrt(sq); } - /// Product of all components. Result dimension is (original dim × N). - pub inline fn product(self: Self) Quantity(T, 1, dims.scale(N).argsOpt(), scales.argsOpt()) { + /// Product of all elements. Result has shape {1}; dimension exponent × total. + pub inline fn product(self: Self) Tensor( + T, + dims.scale(@as(comptime_int, total)).argsOpt(), + scales.argsOpt(), + &.{1}, + ) { return .{ .data = .{@reduce(.Mul, self.data)} }; } - // --------------------------------------------------------------- - // Formatting (unchanged from old Scalar / Vector) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // Formatting + // ─────────────────────────────────────────────────────────────── pub fn formatNumber( self: Self, writer: *std.Io.Writer, options: std.fmt.Number, ) !void { - if (comptime N == 1) { - // Scalar-style: just print the value + units + if (comptime total == 1) { switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[0], options), - .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, } } else { - // Vector-style: (v0, v1, …) + units try writer.writeAll("("); - inline for (0..N) |i| { + inline for (0..total) |i| { if (i > 0) try writer.writeAll(", "); switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[i], options), - .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, @@ -622,43 +757,49 @@ pub fn Quantity( }; } -// Scalar tests +// ═════════════════════════════════════════════════════════════════════════════ +// Tests +// ───────────────────────────────────────────────────────────────────────────── +// Naming convention used throughout: +// Tensor(T, d, s, &.{1}) → former Scalar +// Tensor(T, d, s, &.{N}) → former Vector of length N +// .data[0] → former .value() +// .mul(x) → former .mulScalar(x) (x may be scalar Tensor or bare number) +// .div(x) → former .divScalar(x) +// .eq(x) → former .eqScalar(x) (broadcasts when x.total==1) +// .contract(other, 0, 0) → former .dot(other) (for rank-1 tensors) +// ═════════════════════════════════════════════════════════════════════════════ -pub fn Scalar(comptime T: type, comptime d: Dimensions.ArgOpts, comptime s: Scales.ArgOpts) type { - return Quantity(T, 1, d, s); -} +// ─── Scalar tests ───────────────────────────────────────────────────────── test "Scalar initiat" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }); - const Second = Scalar(f32, .{ .T = 1 }, .{ .T = .n }); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); const distance = Meter.splat(10); - const time = Second.splat(2); + const time = Second.splat(2); - try std.testing.expectEqual(10, distance.value()); - try std.testing.expectEqual(2, time.value()); + try std.testing.expectEqual(10, distance.data[0]); + try std.testing.expectEqual(2, time.data[0]); } test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const m1000 = Meter.splat(1000); - const km1 = KiloMeter.splat(1); - const km2 = KiloMeter.splat(2); + const km1 = KiloMeter.splat(1); + const km2 = KiloMeter.splat(2); - // Equal / Not Equal try std.testing.expect(m1000.eq(km1)); try std.testing.expect(km1.eq(m1000)); try std.testing.expect(km2.ne(m1000)); - // Greater Than / Greater Than or Equal try std.testing.expect(km2.gt(m1000)); try std.testing.expect(km2.gt(km1)); try std.testing.expect(km1.gte(m1000)); try std.testing.expect(km2.gte(m1000)); - // Less Than / Less Than or Equal try std.testing.expect(m1000.lt(km2)); try std.testing.expect(km1.lt(km2)); try std.testing.expect(km1.lte(m1000)); @@ -666,71 +807,65 @@ test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { } test "Scalar Add" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const distance = Meter.splat(10); + const distance = Meter.splat(10); const distance2 = Meter.splat(20); + const added = distance.add(distance2); + try std.testing.expectEqual(30, added.data[0]); + try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); - const added = distance.add(distance2); - try std.testing.expectEqual(30, added.value()); - try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); - - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); const distance3 = KiloMeter.splat(2); - const added2 = distance.add(distance3); - try std.testing.expectEqual(2010, added2.value()); - try std.testing.expectEqual(1, @TypeOf(added2).dims.get(.L)); + const added2 = distance.add(distance3); + try std.testing.expectEqual(2010, added2.data[0]); const added3 = distance3.add(distance).to(KiloMeter); - try std.testing.expectEqual(2, added3.value()); - try std.testing.expectEqual(1, @TypeOf(added3).dims.get(.L)); + try std.testing.expectEqual(2, added3.data[0]); - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); const distance4 = KiloMeter_f.splat(2); - const added4 = distance4.add(distance).to(KiloMeter_f); - try std.testing.expectApproxEqAbs(2.01, added4.value(), 0.000001); - try std.testing.expectEqual(1, @TypeOf(added4).dims.get(.L)); + const added4 = distance4.add(distance).to(KiloMeter_f); + try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001); } test "Scalar Sub" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); - - const a = Meter.splat(500); - const b = Meter.splat(200); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const a = Meter.splat(500); + const b = Meter.splat(200); const diff = a.sub(b); - try std.testing.expectEqual(300, diff.value()); + try std.testing.expectEqual(300, diff.data[0]); const diff2 = b.sub(a); - try std.testing.expectEqual(-300, diff2.value()); + try std.testing.expectEqual(-300, diff2.data[0]); - const km_f = KiloMeter_f.splat(2.5); - const m_f = Meter.splat(500); + const km_f = KiloMeter_f.splat(2.5); + const m_f = Meter.splat(500); const diff3 = km_f.sub(m_f); - try std.testing.expectApproxEqAbs(2000, diff3.value(), 1e-4); + try std.testing.expectApproxEqAbs(2000, diff3.data[0], 1e-4); } test "Scalar MulBy" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const d = Meter.splat(3.0); - const t = Second.splat(4.0); + const d = Meter.splat(3); + const t = Second.splat(4); + const at = d.mul(t); + try std.testing.expectEqual(12, at.data[0]); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L)); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T)); - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.L)); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.T)); - - const d2 = Meter.splat(5.0); + const d2 = Meter.splat(5); const area = d.mul(d2); - try std.testing.expectEqual(15, area.value()); - try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); + try std.testing.expectEqual(15, area.data[0]); + try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); } test "Scalar MulBy with scale" { - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); - const KiloGram = Scalar(f32, .{ .M = 1 }, .{ .M = .k }); + const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); const dist = KiloMeter.splat(2.0); const mass = KiloGram.splat(3.0); @@ -740,126 +875,116 @@ test "Scalar MulBy with scale" { } test "Scalar MulBy with type change" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Second = Scalar(f64, .{ .T = 1 }, .{}); - const KmSec = Scalar(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }); - const KmSec_f = Scalar(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1}); + const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); + const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); - const d = Meter.splat(3.0); - const t = Second.splat(4.0); + const d = Meter.splat(3); + const t = Second.splat(4); - const area_time = d.mul(t).to(KmSec); - const area_time_f = d.mul(t).to(KmSec_f); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectApproxEqAbs(12.0, area_time_f.value(), 0.0001); + try std.testing.expectEqual(12, d.mul(t).to(KmSec).data[0]); + try std.testing.expectApproxEqAbs(12.0, d.mul(t).to(KmSec_f).data[0], 0.0001); } test "Scalar MulBy small" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .n }); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const d = Meter.splat(3.0); - const t = Second.splat(4.0); - - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const d = Meter.splat(3); + const t = Second.splat(4); + try std.testing.expectEqual(12, d.mul(t).data[0]); } test "Scalar MulBy dimensionless" { - const DimLess = Scalar(i128, .{}, .{}); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - - const d = Meter.splat(7); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(7); const scaled = d.mul(DimLess.splat(3)); - try std.testing.expectEqual(21, scaled.value()); + try std.testing.expectEqual(21, scaled.data[0]); } test "Scalar Sqrt" { - const MeterSquare = Scalar(i128, .{ .L = 2 }, .{}); + const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1}); - var d = MeterSquare.splat(9); + var d = MeterSquare.splat(9); var scaled = d.sqrt(); - try std.testing.expectEqual(3, scaled.value()); + try std.testing.expectEqual(3, scaled.data[0]); try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - d = MeterSquare.splat(-5); + d = MeterSquare.splat(-5); scaled = d.sqrt(); - try std.testing.expectEqual(0, scaled.value()); + try std.testing.expectEqual(0, scaled.data[0]); - const MeterSquare_f = Scalar(f64, .{ .L = 2 }, .{}); - const d2 = MeterSquare_f.splat(20); + const d2 = MeterSquare_f.splat(20); const scaled2 = d2.sqrt(); - try std.testing.expectApproxEqAbs(4.472135955, scaled2.value(), 1e-4); + try std.testing.expectApproxEqAbs(4.472135955, scaled2.data[0], 1e-4); } test "Scalar Chained: velocity and acceleration" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const dist = Meter.splat(100.0); - const t1 = Second.splat(5.0); + const dist = Meter.splat(100); + const t1 = Second.splat(5); const velocity = dist.div(t1); - try std.testing.expectEqual(20, velocity.value()); + try std.testing.expectEqual(20, velocity.data[0]); - const t2 = Second.splat(4.0); + const t2 = Second.splat(4); const accel = velocity.div(t2); - try std.testing.expectEqual(5, accel.value()); + try std.testing.expectEqual(5, accel.data[0]); } test "Scalar DivBy integer exact" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); const dist = Meter.splat(120); const time = Second.splat(4); - const vel = dist.div(time); - - try std.testing.expectEqual(30, vel.value()); + const vel = dist.div(time); + try std.testing.expectEqual(30, vel.data[0]); } test "Scalar Finer scales skip dim 0" { - const Dimless = Scalar(i128, .{}, .{}); - const KiloMetre = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); + const Dimless = Tensor(i128, .{}, .{}, &.{1}); + const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const r = Dimless.splat(30); - const time = KiloMetre.splat(4); - const vel = r.mul(time); - - try std.testing.expectEqual(120, vel.value()); + const r = Dimless.splat(30); + const km = KiloMetre.splat(4); + const vel = r.mul(km); + try std.testing.expectEqual(120, vel.data[0]); try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L)); } test "Scalar Conversion chain: km -> m -> cm" { - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const CentiMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .c }); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); const km = KiloMeter.splat(15); - const m = km.to(Meter); + const m = km.to(Meter); const cm = m.to(CentiMeter); - - try std.testing.expectEqual(15_000, m.value()); - try std.testing.expectEqual(1_500_000, cm.value()); + try std.testing.expectEqual(15_000, m.data[0]); + try std.testing.expectEqual(1_500_000, cm.data[0]); } test "Scalar Conversion: hours -> minutes -> seconds" { - const Hour = Scalar(i128, .{ .T = 1 }, .{ .T = .hour }); - const Minute = Scalar(i128, .{ .T = 1 }, .{ .T = .min }); - const Second = Scalar(i128, .{ .T = 1 }, .{}); + const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); + const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); + const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1}); - const h = Hour.splat(1.0); + const h = Hour.splat(1); const min = h.to(Minute); const sec = min.to(Second); - - try std.testing.expectEqual(60, min.value()); - try std.testing.expectEqual(3600, sec.value()); + try std.testing.expectEqual(60, min.data[0]); + try std.testing.expectEqual(3600, sec.data[0]); } -test "Scalar Format Scalar" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const Meter = Scalar(f32, .{ .L = 1 }, .{}); +test "Scalar Format" { + const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1}); + const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - const m = Meter.splat(1.23456); + const m = Meter.splat(1.23456); const accel = MeterPerSecondSq.splat(9.81); var buf: [64]u8 = undefined; @@ -871,96 +996,72 @@ test "Scalar Format Scalar" { } test "Scalar Abs" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const m1 = Meter.splat(-50); - const m2 = m1.abs(); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - try std.testing.expectEqual(50, m2.value()); - - const m_float = Scalar(f32, .{ .L = 1 }, .{}); - const m3 = m_float.splat(-42.5); - try std.testing.expectEqual(42.5, m3.abs().value()); + try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]); + try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]); } test "Scalar Pow" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(4); - - const area = d.pow(2); - try std.testing.expectEqual(16, area.value()); - - const volume = d.pow(3); - try std.testing.expectEqual(64, volume.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(4); + try std.testing.expectEqual(16, d.pow(2).data[0]); + try std.testing.expectEqual(64, d.pow(3).data[0]); } test "Scalar mul comptime_int" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(7); - - const scaled = d.mul(3); - try std.testing.expectEqual(21, scaled.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(7); + try std.testing.expectEqual(21, d.mul(3).data[0]); } test "Scalar add/sub bare number on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); const a = DimLess.splat(10); - - const b = a.add(5); - try std.testing.expectEqual(15, b.value()); - - const c = a.sub(3); - try std.testing.expectEqual(7, c.value()); + try std.testing.expectEqual(15, a.add(5).data[0]); + try std.testing.expectEqual(7, a.sub(3).data[0]); } test "Scalar Imperial length scales" { - const Foot = Scalar(f64, .{ .L = 1 }, .{ .L = .ft }); - const Meter = Scalar(f64, .{ .L = 1 }, .{}); - const Inch = Scalar(f64, .{ .L = 1 }, .{ .L = .inch }); + const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); + const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1}); + const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); - const one_ft = Foot.splat(1.0); - try std.testing.expectApproxEqAbs(0.3048, one_ft.to(Meter).value(), 1e-9); - - const twelve_in = Inch.splat(12.0); - try std.testing.expectApproxEqAbs(1.0, twelve_in.to(Foot).value(), 1e-9); + try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9); + try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9); } test "Scalar Imperial mass scales" { - const Pound = Scalar(f64, .{ .M = 1 }, .{ .M = .lb }); - const Ounce = Scalar(f64, .{ .M = 1 }, .{ .M = .oz }); + const Pound = Tensor(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1}); + const Ounce = Tensor(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1}); - const two_lb = Pound.splat(2.0); - const eight_oz = Ounce.splat(8.0); - const total = two_lb.add(eight_oz).to(Pound); - try std.testing.expectApproxEqAbs(2.5, total.value(), 1e-6); + const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound); + try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6); } test "Scalar comparisons with comptime_int on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); const x = DimLess.splat(42); - try std.testing.expect(x.eq(42)); try std.testing.expect(x.gt(10)); } -// Vector tests - -pub fn Vector(N: comptime_int, Q: type) type { - return Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()); -} +// ─── Vector / Tensor tests ──────────────────────────────────────────────── test "Vector initiate" { - const Meter = Vector(4, Scalar(f32, .{ .L = 1 }, .{})); - const m = Meter.splat(1); - + const Meter4 = Tensor(f32, .{ .L = 1 }, .{}, &.{4}); + const m = Meter4.splat(1); try std.testing.expect(m.data[0] == 1); + try std.testing.expect(m.data[3] == 1); } test "Vector format" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const KgMeterPerSecond = Scalar(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }); + const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); + const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); - const accel = MeterPerSecondSq.Vec3.splat(9.81); - const momentum = KgMeterPerSecond.Vec3{ .data = .{ 43, 0, 11 } }; + const accel = MeterPerSecondSq.splat(9.81); + const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } }; var buf: [64]u8 = undefined; var res = try std.fmt.bufPrint(&buf, "{d}", .{accel}); @@ -971,28 +1072,20 @@ test "Vector format" { } test "Vector Vec3 Init and Basic Arithmetic" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - // Test zero, one, initDefault - const v_zero = Vec3M.zero; + const v_zero = Meter3.zero; try std.testing.expectEqual(0, v_zero.data[0]); - try std.testing.expectEqual(0, v_zero.data[1]); try std.testing.expectEqual(0, v_zero.data[2]); - const v_one = Vec3M.one; + const v_one = Meter3.one; try std.testing.expectEqual(1, v_one.data[0]); - try std.testing.expectEqual(1, v_one.data[1]); - try std.testing.expectEqual(1, v_one.data[2]); - const v_def = Vec3M.splat(5); - try std.testing.expectEqual(5, v_def.data[0]); - try std.testing.expectEqual(5, v_def.data[1]); + const v_def = Meter3.splat(5); try std.testing.expectEqual(5, v_def.data[2]); - // Test add and sub - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 4, 6 } }; + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 4, 6 } }; const added = v1.add(v2); try std.testing.expectEqual(12, added.data[0]); @@ -1000,260 +1093,256 @@ test "Vector Vec3 Init and Basic Arithmetic" { try std.testing.expectEqual(36, added.data[2]); const subbed = v1.sub(v2); - try std.testing.expectEqual(8, subbed.data[0]); + try std.testing.expectEqual(8, subbed.data[0]); try std.testing.expectEqual(16, subbed.data[1]); try std.testing.expectEqual(24, subbed.data[2]); - // Test negate const neg = v1.negate(); try std.testing.expectEqual(-10, neg.data[0]); try std.testing.expectEqual(-20, neg.data[1]); try std.testing.expectEqual(-30, neg.data[2]); } -test "Vector Kinematics (Scalar Mul/Div)" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Second = Scalar(i32, .{ .T = 1 }, .{}); - const Vec3M = Meter.Vec3; +test "Vector Kinematics (scalar mul/div broadcast)" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1}); - const pos = Vec3M{ .data = .{ 100, 200, 300 } }; - const time = Second.splat(10); + const pos = Meter3{ .data = .{ 100, 200, 300 } }; + const time = Second1.splat(10); - // Vector divided by scalar (Velocity = Position / Time) - const vel = pos.divScalar(time); - try std.testing.expectEqual(10, vel.data[0]); - try std.testing.expectEqual(20, vel.data[1]); - try std.testing.expectEqual(30, vel.data[2]); - try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); - try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); + const vel = pos.div(time); + try std.testing.expectEqual(10, vel.data[0]); + try std.testing.expectEqual(20, vel.data[1]); + try std.testing.expectEqual(30, vel.data[2]); + try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); + try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); - // Vector multiplied by scalar (Position = Velocity * Time) - const new_pos = vel.mulScalar(time); + const new_pos = vel.mul(time); try std.testing.expectEqual(100, new_pos.data[0]); - try std.testing.expectEqual(200, new_pos.data[1]); - try std.testing.expectEqual(300, new_pos.data[2]); - try std.testing.expectEqual(1, @TypeOf(new_pos).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); + try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); } test "Vector Element-wise Math and Scaling" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 5, 10 } }; - - // Element-wise division - const div = v1.div(v2); - try std.testing.expectEqual(5, div.data[0]); - try std.testing.expectEqual(4, div.data[1]); - try std.testing.expectEqual(3, div.data[2]); - try std.testing.expectEqual(0, @TypeOf(div).dims.get(.L)); // M / M = Dimensionless + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 5, 10 } }; + const dv = v1.div(v2); + try std.testing.expectEqual(5, dv.data[0]); + try std.testing.expectEqual(4, dv.data[1]); + try std.testing.expectEqual(3, dv.data[2]); + try std.testing.expectEqual(0, @TypeOf(dv).dims.get(.L)); } test "Vector Conversions" { - const KiloMeter = Scalar(i32, .{ .L = 1 }, .{ .L = .k }); - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - - const v_km = KiloMeter.Vec3{ .data = .{ 1, 2, 3 } }; - const v_m = v_km.to(Meter); + const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } }; + const v_m = v_km.to(Meter3); try std.testing.expectEqual(1000, v_m.data[0]); try std.testing.expectEqual(2000, v_m.data[1]); try std.testing.expectEqual(3000, v_m.data[2]); - - // Type checking the result - try std.testing.expectEqual(1, @TypeOf(v_m).dims.get(.L)); try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m).scales.get(.L)); } test "Vector Length" { - const MeterInt = Scalar(i32, .{ .L = 1 }, .{}); - const MeterFloat = Scalar(f32, .{ .L = 1 }, .{}); + const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - // Integer length - // 3-4-5 triangle on XY plane - const v_int = MeterInt.Vec3{ .data = .{ 3, 4, 0 } }; + const v_int = MeterInt3{ .data = .{ 3, 4, 0 } }; try std.testing.expectEqual(25, v_int.lengthSqr()); - try std.testing.expectEqual(5, v_int.length()); + try std.testing.expectEqual(5, v_int.length()); - // Float length - const v_float = MeterFloat.Vec3{ .data = .{ 3.0, 4.0, 0.0 } }; + const v_float = MeterFloat3{ .data = .{ 3.0, 4.0, 0.0 } }; try std.testing.expectApproxEqAbs(@as(f32, 25.0), v_float.lengthSqr(), 1e-4); - try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); + try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); } test "Vector Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); - const v1 = Meter.Vec3{ .data = .{ 1000.0, 500.0, 0.0 } }; - const v2 = KiloMeter.Vec3{ .data = .{ 1.0, 0.5, 0.0 } }; - const v3 = KiloMeter.Vec3{ .data = .{ 1.0, 0.6, 0.0 } }; + const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; + const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; + const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } }; - // 1. Equality (Whole vector) try std.testing.expect(v1.eqAll(v2)); try std.testing.expect(v1.neAll(v3)); - // 2. Element-wise Ordered Comparison - const higher = v3.gt(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(false, higher[0]); // 1km == 1000m - try std.testing.expectEqual(true, higher[1]); // 0.6km > 500m - try std.testing.expectEqual(false, higher[2]); // 0 == 0 + const higher = v3.gt(v1); + try std.testing.expectEqual(false, higher[0]); + try std.testing.expectEqual(true, higher[1]); + try std.testing.expectEqual(false, higher[2]); - // 3. Element-wise Equal Comparison - const equal = v3.eq(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(true, equal[0]); // 1km == 1000m - try std.testing.expectEqual(false, equal[1]); // 0.6km > 500m - try std.testing.expectEqual(true, equal[2]); // 0 == 0 + const equal = v3.eq(v1); + try std.testing.expectEqual(true, equal[0]); + try std.testing.expectEqual(false, equal[1]); + try std.testing.expectEqual(true, equal[2]); - // 3. Less than or equal const low_eq = v1.lte(v3); try std.testing.expect(low_eq[0] and low_eq[1] and low_eq[2]); } -test "Vector vs Scalar Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); +test "Vector vs Scalar broadcast comparison" { + // Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs. + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const positions = Meter.Vec3{ .data = .{ 500.0, 1200.0, 3000.0 } }; - const threshold = KiloMeter.splat(1); // 1km (1000m) + const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } }; + const threshold = KiloMeter1.splat(1); // 1 km = 1000 m - // Check which axes exceed the 1km threshold - const exceeded = positions.gtScalar(threshold); + const exceeded = positions.gt(threshold); + try std.testing.expectEqual(false, exceeded[0]); + try std.testing.expectEqual(true, exceeded[1]); + try std.testing.expectEqual(true, exceeded[2]); - try std.testing.expectEqual(false, exceeded[0]); // 500m > 1km is false - try std.testing.expectEqual(true, exceeded[1]); // 1200m > 1km is true - try std.testing.expectEqual(true, exceeded[2]); // 3000m > 1km is true - - // Check for equality (broadcasted) - const exact_match = positions.eqScalar(Meter.splat(500)); - try std.testing.expect(exact_match[0] == true); - try std.testing.expect(exact_match[1] == false); + const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const exact = positions.eq(Meter1.splat(500)); + try std.testing.expect(exact[0] == true); + try std.testing.expect(exact[1] == false); } -test "Vector Dot and Cross Products" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const Newton = Scalar(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}); +test "Vector contract — dot product (rank-1 × rank-1)" { + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); - const pos = Meter.Vec3{ .data = .{ 10.0, 0.0, 0.0 } }; - const force = Newton.Vec3{ .data = .{ 5.0, 5.0, 0.0 } }; + const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; + const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; - // 1. Dot Product (Work = F dot d) - const work = force.dot(pos); - try std.testing.expectEqual(50.0, work.value()); - // Dimensions should be M¹L²T⁻² (Energy/Joules) + // work = force · pos + const work = force.contract(pos, 0, 0); + try std.testing.expectEqual(50.0, work.data[0]); try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); try std.testing.expectEqual(2, @TypeOf(work).dims.get(.L)); try std.testing.expectEqual(-2, @TypeOf(work).dims.get(.T)); +} - // 2. Cross Product (Torque = r cross F) - const torque = pos.cross(force); - try std.testing.expectEqual(0.0, torque.data[0]); - try std.testing.expectEqual(0.0, torque.data[1]); - try std.testing.expectEqual(50.0, torque.data[2]); - // Torque dimensions are same as Energy but as a Vector - try std.testing.expectEqual(2, @TypeOf(torque).dims.get(.L)); +test "Vector contract — matrix multiply (rank-2 × rank-2)" { + // 2×3 matrix multiplied by 3×2 matrix → 2×2 result + const A = Tensor(f32, .{}, .{}, &.{2, 3}); + const B = Tensor(f32, .{}, .{}, &.{3, 2}); + + // A = [[1, 2, 3], + // [4, 5, 6]] + const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; + // B = [[7, 8], + // [9, 10], + // [11, 12]] + const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; + + // C = A @ B (contract over axis 1 of A × axis 0 of B) + // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58 + // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64 + // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139 + // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154 + const c = a.contract(b, 1, 0); + try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 0})]); + try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 1})]); + try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 0})]); + try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 1})]); } test "Vector Abs, Pow, Sqrt and Product" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const v1 = Meter.Vec3{ .data = .{ -2.0, 3.0, -4.0 } }; - - // 1. Abs + const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; const v_abs = v1.abs(); try std.testing.expectEqual(2.0, v_abs.data[0]); try std.testing.expectEqual(4.0, v_abs.data[2]); - // 2. Product (L1 * L1 * L1 = L3) const vol = v_abs.product(); - try std.testing.expectEqual(24.0, vol.value()); + try std.testing.expectEqual(24.0, vol.data[0]); try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L)); - // 3. Pow (Scalar exponent: (L1)^2 = L2) const area_vec = v_abs.pow(2); - try std.testing.expectEqual(4.0, area_vec.data[0]); + try std.testing.expectEqual(4.0, area_vec.data[0]); try std.testing.expectEqual(16.0, area_vec.data[2]); try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L)); - // 4. Sqrt const sqrted = area_vec.sqrt(); try std.testing.expectEqual(2, sqrted.data[0]); try std.testing.expectEqual(4, sqrted.data[2]); try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L)); } -test "Vector mulScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 1, 2, 3 } }; - - const scaled = v.mulScalar(10); // comptime_int → dimensionless +test "Vector mul comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 1, 2, 3 } }; + const scaled = v.mul(10); try std.testing.expectEqual(10, scaled.data[0]); try std.testing.expectEqual(20, scaled.data[1]); try std.testing.expectEqual(30, scaled.data[2]); - // Dimensions unchanged: L¹ × dimensionless = L¹ try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(scaled).dims.get(.T)); } -test "Vector mulScalar comptime_float" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 1.0, 2.0, 4.0 } }; - - const scaled = v.mulScalar(0.5); // comptime_float → dimensionless +test "Vector mul comptime_float broadcast" { + const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; + const scaled = v.mul(0.5); try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6); try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6); try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); } -test "Vector mulScalar T (value type)" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 3.0, 6.0, 9.0 } }; - const factor: f32 = 2.0; - - const scaled = v.mulScalar(factor); - try std.testing.expectApproxEqAbs(6.0, scaled.data[0], 1e-6); - try std.testing.expectApproxEqAbs(12.0, scaled.data[1], 1e-6); - try std.testing.expectApproxEqAbs(18.0, scaled.data[2], 1e-6); - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); -} - -test "Vector divScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 10, 20, 30 } }; - - const halved = v.divScalar(2); // comptime_int → dimensionless divisor - try std.testing.expectEqual(5, halved.data[0]); +test "Vector div comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 10, 20, 30 } }; + const halved = v.div(2); + try std.testing.expectEqual(5, halved.data[0]); try std.testing.expectEqual(10, halved.data[1]); try std.testing.expectEqual(15, halved.data[2]); try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L)); } -test "Vector divScalar comptime_float" { - const MeterF = Scalar(f64, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 9.0, 6.0, 3.0 } }; - - const r = v.divScalar(3.0); +test "Vector div comptime_float broadcast" { + const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3}); + const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; + const r = v.div(3.0); try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9); try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9); - try std.testing.expectEqual(1, @TypeOf(r).dims.get(.L)); } -test "Vector eqScalar / gtScalar with comptime_int on dimensionless vector" { - // Bare numbers are dimensionless, so comparisons only work when vector is dimensionless too. - const DimLess = Scalar(i32, .{}, .{}); - const v = DimLess.Vec3{ .data = .{ 1, 2, 3 } }; +test "Vector eq broadcast on dimensionless" { + const DimLess3 = Tensor(i32, .{}, .{}, &.{3}); + const v = DimLess3{ .data = .{ 1, 2, 3 } }; - const eq_res = v.eqScalar(2); + const eq_res = v.eq(2); try std.testing.expectEqual(false, eq_res[0]); - try std.testing.expectEqual(true, eq_res[1]); + try std.testing.expectEqual(true, eq_res[1]); try std.testing.expectEqual(false, eq_res[2]); - const gt_res = v.gtScalar(1); + const gt_res = v.gt(1); try std.testing.expectEqual(false, gt_res[0]); - try std.testing.expectEqual(true, gt_res[1]); - try std.testing.expectEqual(true, gt_res[2]); + try std.testing.expectEqual(true, gt_res[1]); + try std.testing.expectEqual(true, gt_res[2]); +} + +test "Tensor idx helper and matrix access" { + const Mat3x3 = Tensor(f32, .{}, .{}, &.{3, 3}); + // Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3 + var m: Mat3x3 = Mat3x3.zero; + m.data[Mat3x3.idx(.{0, 0})] = 1.0; + m.data[Mat3x3.idx(.{1, 1})] = 2.0; + m.data[Mat3x3.idx(.{2, 2})] = 3.0; + + try std.testing.expectEqual(1.0, m.data[0]); // [0][0] + try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4) + try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8) + try std.testing.expectEqual(0.0, m.data[1]); // [0][1] +} + +test "Tensor strides_arr correctness" { + const T1 = Tensor(f32, .{}, .{}, &.{3}); + const T2 = Tensor(f32, .{}, .{}, &.{3, 4}); + const T3 = Tensor(f32, .{}, .{}, &.{2, 3, 4}); + + try std.testing.expectEqual(1, T1.strides_arr[0]); + try std.testing.expectEqual(4, T2.strides_arr[0]); + try std.testing.expectEqual(1, T2.strides_arr[1]); + try std.testing.expectEqual(12, T3.strides_arr[0]); + try std.testing.expectEqual(4, T3.strides_arr[1]); + try std.testing.expectEqual(1, T3.strides_arr[2]); }