From 934a40fe1a643534488b43f35d9139de476f6e1b Mon Sep 17 00:00:00 2001 From: Adrien Bouvais Date: Sun, 26 Apr 2026 22:16:25 +0200 Subject: [PATCH 01/10] Basic untested tensor --- src/Quantity.zig | 1405 ++++++++++++++++++++++++---------------------- 1 file changed, 747 insertions(+), 658 deletions(-) diff --git a/src/Quantity.zig b/src/Quantity.zig index 4e30ec9..f0edea6 100644 --- a/src/Quantity.zig +++ b/src/Quantity.zig @@ -5,240 +5,385 @@ const UnitScale = Scales.UnitScale; const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; -// --------------------------------------------------------------------------- -// Quantity — single unified dimensioned type. -// -// T : element numeric type (f32, f64, i32, i128, …) -// N : lane count (1 = Scalar, >1 = Vector) -// d : SI dimension exponents -// s : unit scales -// -// All arithmetic is performed directly on the underlying @Vector(N, T), so -// the compiler can emit SIMD instructions wherever the target supports them. 
-// -// Thin aliases (same type identity, no wrapper overhead): -// Scalar(T, d, s) ≡ Quantity(T, 1, d, s) -// Vector(N, Q) ≡ Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()) -// --------------------------------------------------------------------------- -// -// To investigate: -// @reduce(comptime op: std.builtin.ReduceOp, value: anytype) E -// @select(comptime T: type, pred: @Vector(len, bool), a: @Vector(len, T), b: @Vector(len, T)) @Vector(len, T) -// @shuffle(comptime E: type, a: @Vector(a_len, E), b: @Vector(b_len, E), comptime mask: @Vector(mask_len, i32)) @Vector(mask_len, E) +// ───────────────────────────────────────────────────────────────────────────── +// Comptime shape utilities +// ───────────────────────────────────────────────────────────────────────────── -pub fn Quantity( +pub fn shapeTotal(comptime shape: []const usize) usize { + var t: usize = 1; + for (shape) |s| t *= s; + return t; +} + +/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). +/// e.g. shape {3, 4} → strides {4, 1} +/// shape {2, 3, 4} → strides {12, 4, 1} +pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { + var st: [shape.len]usize = undefined; + if (shape.len == 0) return st; + st[shape.len - 1] = 1; + if (shape.len > 1) { + var i: usize = shape.len - 1; + while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; + } + return st; +} + +/// Return a copy of `shape` with the element at `axis` removed. +pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { + var out: [shape.len - 1]usize = undefined; + var j: usize = 0; + for (shape, 0..) |v, i| { + if (i != axis) { out[j] = v; j += 1; } + } + return out; +} + +/// Concatenate two compile-time slices. +pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { + var out: [a.len + b.len]usize = undefined; + for (a, 0..) |v, i| out[i] = v; + for (b, 0..) 
|v, i| out[a.len + i] = v; + return out; +} + +/// Decode a flat row-major index into N-D coordinates. +/// Called only in comptime contexts (all arguments are comptime). +pub fn decodeFlatCoords( + comptime flat: usize, + comptime n: usize, + comptime strd: [n]usize, +) [n]usize { + var coords: [n]usize = undefined; + var tmp = flat; + for (0..n) |i| { + coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; + tmp = if (strd[i] == 0) 0 else tmp % strd[i]; + } + return coords; +} + +/// Encode N-D coordinates into a flat row-major index. +/// Called only in comptime contexts. +pub fn encodeFlatCoords( + comptime coords: []const usize, + comptime n: usize, + comptime strd: [n]usize, +) usize { + var flat: usize = 0; + for (0..n) |i| flat += coords[i] * strd[i]; + return flat; +} + +/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`. +/// `free` holds the remaining (non-contracted) coordinates in order. +pub fn insertAxis( + comptime n: usize, + comptime axis: usize, + comptime val: usize, + comptime free: []const usize, +) [n]usize { + var out: [n]usize = undefined; + var fi: usize = 0; + for (0..n) |i| { + if (i == axis) { + out[i] = val; + } else { + out[i] = free[fi]; + fi += 1; + } + } + return out; +} + +// ───────────────────────────────────────────────────────────────────────────── +// File-scope RHS normalisation helpers +// +// Any bare comptime_int / comptime_float / runtime T used as an arithmetic +// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}. +// Actual Tensor types are passed through unchanged. 
+// ───────────────────────────────────────────────────────────────────────────── + +fn RhsTensorType(comptime T: type, comptime Rhs: type) type { + if (@hasDecl(Rhs, "ISTENSOR")) return Rhs; + return Tensor(T, .{}, .{}, &.{1}); +} + +fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { + const Rhs = @TypeOf(r); + if (comptime @hasDecl(Rhs, "ISTENSOR")) return r; + const scalar: T = switch (comptime @typeInfo(Rhs)) { + .comptime_int => switch (comptime @typeInfo(T)) { + .float => @as(T, @floatFromInt(r)), + else => @as(T, r), + }, + .comptime_float => switch (comptime @typeInfo(T)) { + .int => @as(T, @intFromFloat(r)), + else => @as(T, r), + }, + .int => switch (comptime @typeInfo(T)) { + .float => @floatFromInt(r), + else => @intCast(r), + }, + .float => switch (comptime @typeInfo(T)) { + .int => @intFromFloat(r), + else => @floatCast(r), + }, + else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)), + }; + return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tensor — unified dimensioned ND type. +// +// T : element numeric type (f32, f64, i32, i128, …) +// d_opt : SI dimension exponents +// s_opt : unit scales +// shape_ : compile-time shape +// &.{1} → scalar +// &.{3} → 3-vector +// &.{4, 4} → 4×4 matrix +// &.{3, 3, 3} → 3D field +// +// Storage: flat @Vector(total, T) where total = product(shape_). +// All arithmetic operates on the flat vector directly → SIMD wherever possible. 
+// +// Shape-related comptime constants exposed on every Tensor type: +// dims : Dimensions — SI exponent struct +// scales : Scales — unit scale struct +// shape : []const usize +// rank : usize = shape.len +// total : usize = product(shape) +// strides_arr : [rank]usize — row-major strides +// +// Index helper: +// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) +// +// GPU readiness: +// tensor.asSlice() → []T (zero-copy pointer to the flat @Vector storage) +// +// Contraction (replaces dot / cross / matmul): +// a.contract(b, axis_a, axis_b) +// For rank-1 × rank-1 this is the dot product. +// For rank-2 × rank-2 with axis_a=1, axis_b=0 this is matrix multiply. +// +// Removed from Quantity: +// Scalar / Vector aliases, Vec3 / ScalarType, .value(), .vec(), .vec3(), +// dot(), cross(), mulScalar(), divScalar(), eqScalar() and friends. +// Use Tensor(..., &.{1}), .data[0], mul(), div(), eq() respectively. +// ───────────────────────────────────────────────────────────────────────────── + +pub fn Tensor( comptime T: type, - comptime N: usize, comptime d_opt: Dimensions.ArgOpts, comptime s_opt: Scales.ArgOpts, + comptime shape_: []const usize, ) type { - comptime std.debug.assert(N >= 1); + comptime { + std.debug.assert(shape_.len >= 1); + for (shape_) |s| std.debug.assert(s >= 1); + } @setEvalBranchQuota(10_000_000); - // Local shorthand for the SIMD vector type used in storage. - const Vec = @Vector(N, T); + const _total: usize = comptime shapeTotal(shape_); + const _strides = comptime shapeStrides(shape_); + const Vec = @Vector(_total, T); return struct { - /// SIMD-friendly storage. Arithmetic operates here directly. + /// Flat SIMD storage. All arithmetic operates here directly. 
data: Vec, const Self = @This(); - pub const ValueType: type = T; - pub const Len: usize = N; - pub const dims: Dimensions = Dimensions.init(d_opt); - pub const scales: Scales = Scales.init(s_opt); - pub const ISQUANTITY = true; + pub const ValueType : type = T; + pub const dims : Dimensions = Dimensions.init(d_opt); + pub const scales : Scales = Scales.init(s_opt); + pub const shape : []const usize = shape_; + pub const rank : usize = shape_.len; + pub const total : usize = _total; + pub const strides_arr: [shape_.len]usize = _strides; + pub const ISTENSOR = true; - /// Scalar variant of this quantity (lane=1). Returned by dot(), product(), etc. - pub const ScalarType: type = Quantity(T, 1, d_opt, s_opt); - /// Convenience: a 3-lane vector of the same dimension/scale. - pub const Vec3: type = Quantity(T, 3, d_opt, s_opt); + // ─────────────────────────────────────────────────────────────── + // Index helper + // ─────────────────────────────────────────────────────────────── - // --------------------------------------------------------------- + /// Convert N-D coords (row-major) to flat index — fully comptime. + /// Usage: Tensor.idx(.{row, col}) + pub fn idx(comptime coords: [rank]usize) usize { + comptime { + var flat: usize = 0; + for (0..rank) |i| { + std.debug.assert(coords[i] < shape[i]); + flat += coords[i] * strides_arr[i]; + } + return flat; + } + } + + // ─────────────────────────────────────────────────────────────── // Constructors - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Broadcast a single value across all N lanes. + /// Broadcast a single value across all elements. 
pub inline fn splat(v: T) Self { return .{ .data = @splat(v) }; } pub const zero: Self = splat(0); - pub const one: Self = splat(1); + pub const one: Self = splat(1); - // --------------------------------------------------------------- - // Scalar-only helpers (N = 1) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // GPU readiness + // ─────────────────────────────────────────────────────────────── - /// Return the single scalar value. Compile error when N ≠ 1. - pub inline fn value(self: Self) T { - comptime if (N != 1) - @compileError(".value() is only available on Scalar (N=1)."); - return self.data[0]; + /// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping. + pub inline fn asSlice(self: *Self) []T { + return @as([*]T, @ptrCast(&self.data))[0..total]; } - /// Expand this scalar into a len-lane vector by splatting. - pub inline fn vec(self: Self, comptime len: usize) Quantity(T, len, d_opt, s_opt) { - comptime if (N != 1) - @compileError(".vec() is only available on Scalar (N=1)."); - return .{ .data = @splat(self.data[0]) }; - } - - pub inline fn vec3(self: Self) Vec3 { - return self.vec(3); - } - - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Internal: RHS normalisation - // - // • For N=1 (Scalar context): bare numbers → Quantity(T, 1, dimless, none) - // • For N>1 (Vector context): bare numbers → Quantity(T, N, dimless, none) - // Quantity(T,1) → broadcast (handled in each op) - // - // A bare number used as rhs is ALWAYS treated as dimensionless. 
- // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - inline fn RhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, N, Rhs); - } - inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, N, r); - } + inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } + inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { return toRhsTensor(T, r); } - /// Scalar rhs (N=1) — used by mulScalar / divScalar / eqScalar etc. - inline fn ScalarRhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, 1, Rhs); - } - inline fn scalarRhs(r: anytype) ScalarRhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, 1, r); - } + // ─────────────────────────────────────────────────────────────── + // Internal: scalar broadcast (shape {1} → full Vec) + // ─────────────────────────────────────────────────────────────── - // --------------------------------------------------------------- - // Internal: broadcast helper - // - // When an N=1 rhs is used in an N>1 operation, splat it. - // --------------------------------------------------------------- inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { - if (comptime RhsType.Len == 1 and N > 1) - return @splat(r.data[0]) + return if (comptime RhsType.total == 1 and total > 1) + @splat(r.data[0]) else - return r.data; + r.data; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Arithmetic - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── /// Element-wise add. Dimensions must match; scales resolve to finer. - /// For N=1: rhs may be a bare number (treated as dimensionless). - /// For N>1: rhs must be a same-length Quantity. 
- pub inline fn add(self: Self, r: anytype) Quantity( + /// RHS must have the same element count as self, or total == 1 (broadcast). + pub inline fn add(self: Self, r: anytype) Tensor( T, - N, dims.argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector add requires same-length Quantity."); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); + }; return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr }; } /// Element-wise subtract. Dimensions must match; scales resolve to finer. 
- pub inline fn sub(self: Self, r: anytype) Quantity( + pub inline fn sub(self: Self, r: anytype) Tensor( T, - N, dims.argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector sub requires same-length Quantity."); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in sub."); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); + }; return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr }; } - /// Element-wise multiply. Dimension exponents are summed. - /// An N=1 rhs on an N>1 self is automatically broadcast (scalar × vector). - pub inline fn mul(self: Self, r: anytype) Quantity( + /// Element-wise multiply. Dimension exponents summed. + /// Shape {1} RHS is automatically broadcast across all elements. 
+ pub inline fn mul(self: Self, r: anytype) Tensor( T, - N, dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in mul."); + + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr }; } - /// Element-wise divide. Dimension exponents are subtracted. - /// An N=1 rhs on an N>1 self is automatically broadcast. - pub inline fn div(self: Self, r: anytype) Quantity( + /// Element-wise divide. Dimension exponents subtracted. + /// Shape {1} RHS is automatically broadcast across all elements. 
+ pub inline fn div(self: Self, r: anytype) Tensor( T, - N, dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); + if (comptime RhsType.total != total and RhsType.total != 1) + @compileError("Shape mismatch in div."); + + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); if (comptime hlp.isInt(T)) { var result: Vec = undefined; - inline for (0..N) |i| result[i] = @divTrunc(l[i], rr[i]); + inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]); return .{ .data = result }; } else { return .{ .data = l / rr }; } } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Unary - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Absolute value of every lane. Uses native `@abs` (SIMD for floats & ints). + /// Absolute value of every element. 
pub inline fn abs(self: Self) Self { return .{ .data = @bitCast(@abs(self.data)) }; } - /// Raise every lane to a comptime integer exponent. - /// Repeated SIMD multiply — good for small exponents. - pub inline fn pow(self: Self, comptime exp: comptime_int) Quantity( + /// Raise every element to a comptime integer exponent. + pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor( T, - N, dims.scale(exp).argsOpt(), scales.argsOpt(), + shape_, ) { if (comptime hlp.isInt(T)) { - // No SIMD pow for integers — element-wise std.math.powi. var result: Vec = undefined; - inline for (0..N) |i| + inline for (0..total) |i| result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); return .{ .data = result }; } else { - // Float: unrolled SIMD multiplications. const abs_exp = comptime @abs(exp); var result: Vec = @splat(1); comptime var i = 0; @@ -248,130 +393,127 @@ pub fn Quantity( } } - /// Square root of every lane. All dimension exponents must be even. - pub inline fn sqrt(self: Self) Quantity( + /// Square root of every element. All dimension exponents must be even. + pub inline fn sqrt(self: Self) Tensor( T, - N, dims.div(2).argsOpt(), scales.argsOpt(), + shape_, ) { if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { return .{ .data = @sqrt(self.data) }; } else { - // Integer sqrt is not SIMD-able — element-wise. var result: Vec = undefined; const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - inline for (0..N) |i| { + inline for (0..total) |i| { const v = self.data[i]; - if (v < 0) - result[i] = 0 - else - result[i] = @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); + result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } return .{ .data = result }; } } - /// Negate every lane. + /// Negate every element. 
pub inline fn negate(self: Self) Self { return .{ .data = -self.data }; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Conversion - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── - /// Convert to a compatible quantity type. Dimension mismatch is a compile error. - /// Dest can have the same Len as this quantity, or Len == 1 (in which case it - /// will be automatically recast to this quantity's Len). - /// The scale ratio is computed entirely at comptime; the only runtime cost is - /// a SIMD multiply-by-splat (or element-wise cast for cross-numeric-type conversions). + /// Convert to a compatible Tensor type. + /// • Dimension mismatch → compile error. + /// • Dest.total must equal self.total, or Dest.total == 1 (scalar pattern). + /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. 
pub inline fn to( self: Self, comptime Dest: type, - ) Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()) { - const ActualDest = Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()); + ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { + const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - if (comptime Self == ActualDest) return self; - // Allow Dest to be exactly matching Len or a Scalar (Len == 1) - comptime std.debug.assert(Dest.Len == N or Dest.Len == 1); + comptime std.debug.assert(Dest.total == total or Dest.total == 1); - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestVec = @Vector(N, DestT); + const DestT = ActualDest.ValueType; + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestVec = @Vector(total, DestT); - // ── Same numeric type path ── + // ── Same numeric type ────────────────────────────────────── if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; - // Integer logic: Branching prevents division by zero errors + // Integer — branch prevents division-by-zero if (comptime ratio >= 1.0) { - // Upscaling (e.g., km -> m, ratio = 1000) const mult: T = comptime @intFromFloat(@round(ratio)); return .{ .data = self.data *| @as(Vec, @splat(mult)) }; } else { - // Downscaling (e.g., m -> km, ratio = 0.001) const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); + const half: T = comptime @divTrunc(div_val, 2); var result: DestVec = undefined; - const half: T = comptime @divTrunc(div_val, 2); - - inline for (0..N) |i| { + inline for (0..total) |i| { const val = 
self.data[i]; - // Rounding division for integers - result[i] = if (val >= 0) @divTrunc(val + half, div_val) else @divTrunc(val - half, div_val); + result[i] = if (val >= 0) + @divTrunc(val + half, div_val) + else + @divTrunc(val - half, div_val); } return .{ .data = result }; } } - // ── Cross-numeric-type (unchanged) ── + // ── Cross numeric type ───────────────────────────────────── var result: DestVec = undefined; - inline for (0..N) |i| { + inline for (0..total) |i| { const float_val: f64 = switch (comptime @typeInfo(T)) { .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, + .int => @floatFromInt(self.data[i]), + else => unreachable, }; const scaled = float_val * ratio; result[i] = switch (comptime @typeInfo(DestT)) { .float => @floatCast(scaled), - .int => @intFromFloat(@round(scaled)), - else => unreachable, + .int => @intFromFloat(@round(scaled)), + else => unreachable, }; } return .{ .data = result }; } - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── // Comparisons // - // Return type: bool when N = 1 (Scalar semantics) - // [N]bool when N > 1 (Vector semantics, element-wise) + // Return type: bool when total == 1 (scalar semantics) + // [total]bool when total > 1 (element-wise, flat-indexed) // - // Whole-vector "all equal/any differ" → use eqAll / neAll. - // Broadcast scalar comparison → use eqScalar / gtScalar / … - // --------------------------------------------------------------- + // Whole-tensor equality check → eqAll / neAll (always returns bool). + // A shape {1} RHS is broadcast automatically, unifying the old + // eqScalar / gtScalar / … family into the plain eq / gt / … methods. 
+ // ─────────────────────────────────────────────────────────────── - const CmpResult = if (N == 1) bool else [N]bool; + const CmpResult = if (total == 1) bool else [total]bool; - inline fn cmpResult(v: @Vector(N, bool)) CmpResult { - return if (comptime N == 1) v[0] else @as([N]bool, v); + inline fn cmpResult(v: @Vector(total, bool)) CmpResult { + return if (comptime total == 1) v[0] else @as([total]bool, v); } + /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { - const RhsType = @TypeOf(rhs_q); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - return .{ - .l = if (comptime Self == TargetType) self.data else self.to(TargetType).data, - .r = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data, + const RhsType = @TypeOf(rhs_q); + const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); }; + return .{ .l = l, .r = rr }; } pub inline fn eq(self: Self, r: anytype) CmpResult { @@ -422,11 +564,7 @@ pub fn Quantity( return cmpResult(p.l <= p.r); } - // --------------------------------------------------------------- - // Vector whole-quantity comparisons (N > 1 intended, but work for N=1 too) - // --------------------------------------------------------------- - - /// True iff every lane is equal after scale resolution. + /// True iff every element is equal after scale resolution. 
pub inline fn eqAll(self: Self, other: anytype) bool { if (comptime !dims.eql(@TypeOf(other).dims)) @compileError("Dimension mismatch in eqAll."); @@ -434,121 +572,115 @@ pub fn Quantity( return @reduce(.And, p.l == p.r); } + /// True iff any element differs after scale resolution. pub inline fn neAll(self: Self, other: anytype) bool { return !self.eqAll(other); } - // --------------------------------------------------------------- - // Vector broadcast-scalar comparisons (always returns [N]bool) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // Contraction — generalised dot product / matrix multiply / einsum + // + // a.contract(b, axis_a, axis_b) + // + // Sums over dimension `axis_a` of `a` and `axis_b` of `b`. + // Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime). + // + // Result shape = a.shape \ axis_a ++ b.shape \ axis_b + // Result dims = a.dims + b.dims (exponents summed, as in mul) + // Result scales = finer of a, b + // + // Special cases: + // rank-1 × rank-1, axis 0 × 0 → dot product (result shape {1}) + // rank-2 × rank-2, axis 1 × 0 → matrix multiply + // rank-1 × rank-2, axis 0 × 0 → vector–matrix product + // + // All index arithmetic is comptime; runtime cost is the multiply-add loop only. 
+ // ─────────────────────────────────────────────────────────────── - inline fn broadcastScalarForCmp(self: Self, scalar: anytype) struct { l: Vec, r: Vec } { - const s = scalarRhs(scalar); - const SN = @TypeOf(s); - const TargetScalar = Quantity(T, 1, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const TargetSelf = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const s_val: T = if (comptime SN == TargetScalar) s.data[0] else s.to(TargetScalar).data[0]; - const l: Vec = if (comptime Self == TargetSelf) self.data else self.to(TargetSelf).data; - return .{ .l = l, .r = @splat(s_val) }; + pub inline fn contract( + self: Self, + other: anytype, + comptime axis_a: usize, + comptime axis_b: usize, + ) blk: { + const OT = @TypeOf(other); + comptime std.debug.assert(axis_a < rank); + comptime std.debug.assert(axis_b < OT.rank); + comptime std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); + // Contracted-away free axes; empty joint → scalar shape {1} + const sa = shapeRemoveAxis(shape_, axis_a); + const sb = shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = shapeCat(&sa, &sb); + const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; + break :blk Tensor( + T, + dims.add(OT.dims).argsOpt(), + hlp.finerScales(Self, OT).argsOpt(), + rs, + ); + } { + const OT = @TypeOf(other); + const k: usize = comptime shape_[axis_a]; // contraction dimension + + const sa = comptime shapeRemoveAxis(shape_, axis_a); + const sb = comptime shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = comptime shapeCat(&sa, &sb); + const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; + + const ResultType = Tensor( + T, + dims.add(OT.dims).argsOpt(), + hlp.finerScales(Self, OT).argsOpt(), + rs, + ); + + // Normalise scales before accumulation + const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), shape_); + const OtherNorm = Tensor(T, OT.dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), OT.shape); + 
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; + + // Precompute result strides from rs_raw (for coord decoding) + const rs_raw_strides = comptime shapeStrides(&rs_raw); + + var result: ResultType = .{ .data = @splat(0) }; + + inline for (0..ResultType.total) |res_flat| { + // Decode result flat index into free coords using rs_raw layout. + // When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} — correct. + const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); + + const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; + const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; + + var acc: T = 0; + inline for (0..k) |ki| { + // Reinsert the contracted index into free coords → full coord arrays + const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); + const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); + const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + + if (comptime hlp.isInt(T)) + acc +|= a_data[a_flat] *| b_data[b_flat] + else + acc += a_data[a_flat] * b_data[b_flat]; + } + result.data[res_flat] = acc; + } + return result; } - pub inline fn eqScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l == p.r); - } + // ─────────────────────────────────────────────────────────────── + // Reduction helpers + // ─────────────────────────────────────────────────────────────── - pub inline fn neScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l != p.r); - } - - pub inline fn gtScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l > p.r); - } - - pub inline fn 
gteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l >= p.r); - } - - pub inline fn ltScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l < p.r); - } - - pub inline fn lteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l <= p.r); - } - - // --------------------------------------------------------------- - // Vector broadcast multiply / divide - // (These are explicit aliases for mul/div with an N=1 rhs; kept for - // clarity and backward-compat with the old Vector API.) - // --------------------------------------------------------------- - - pub inline fn mulScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.add(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.mul(scalar); - } - - pub inline fn divScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.sub(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.div(scalar); - } - - // --------------------------------------------------------------- - // Vector geometric operations - // --------------------------------------------------------------- - - /// Dot product — sum of element-wise products; returns a Scalar. 
- pub inline fn dot(self: Self, other: anytype) Quantity( - T, - 1, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - const Tr = @TypeOf(other); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const OtherNorm = Quantity(T, N, Tr.dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const r2: Vec = if (comptime Tr == OtherNorm) other.data else other.to(OtherNorm).data; - return .{ .data = .{@reduce(.Add, l * r2)} }; - } - - /// 3D cross product. Requires N = 3. - pub inline fn cross(self: Self, other: anytype) Quantity( - T, - 3, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - comptime if (N != 3) @compileError("cross() requires len=3."); - const a = self.data; - const b = other.data; - return .{ .data = .{ - a[1] * b[2] - a[2] * b[1], - a[2] * b[0] - a[0] * b[2], - a[0] * b[1] - a[1] * b[0], - } }; - } - - /// Sum of squared components. Cheaper than length(); use for comparisons. + /// Sum of squared elements. Cheaper than length(); use for ordering. pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); } - /// Euclidean length. Float types use SIMD @reduce → @sqrt. - /// Integer types use integer sqrt (truncated). + /// Euclidean length (L2 norm). pub inline fn length(self: Self) T { const sq = self.lengthSqr(); if (comptime @typeInfo(T) == .int) { @@ -558,43 +690,46 @@ pub fn Quantity( return @sqrt(sq); } - /// Product of all components. Result dimension is (original dim × N). - pub inline fn product(self: Self) Quantity(T, 1, dims.scale(N).argsOpt(), scales.argsOpt()) { + /// Product of all elements. Result has shape {1}; dimension exponent × total. 
+ pub inline fn product(self: Self) Tensor( + T, + dims.scale(@as(comptime_int, total)).argsOpt(), + scales.argsOpt(), + &.{1}, + ) { return .{ .data = .{@reduce(.Mul, self.data)} }; } - // --------------------------------------------------------------- - // Formatting (unchanged from old Scalar / Vector) - // --------------------------------------------------------------- + // ─────────────────────────────────────────────────────────────── + // Formatting + // ─────────────────────────────────────────────────────────────── pub fn formatNumber( self: Self, writer: *std.Io.Writer, options: std.fmt.Number, ) !void { - if (comptime N == 1) { - // Scalar-style: just print the value + units + if (comptime total == 1) { switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[0], options), - .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, } } else { - // Vector-style: (v0, v1, …) + units try writer.writeAll("("); - inline for (0..N) |i| { + inline for (0..total) |i| { if (i > 0) try writer.writeAll(", "); switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[i], options), - .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, @@ -622,43 +757,49 @@ pub fn Quantity( }; } -// Scalar tests +// ═════════════════════════════════════════════════════════════════════════════ +// Tests +// ───────────────────────────────────────────────────────────────────────────── 
+// Naming convention used throughout: +// Tensor(T, d, s, &.{1}) → former Scalar +// Tensor(T, d, s, &.{N}) → former Vector of length N +// .data[0] → former .value() +// .mul(x) → former .mulScalar(x) (x may be scalar Tensor or bare number) +// .div(x) → former .divScalar(x) +// .eq(x) → former .eqScalar(x) (broadcasts when x.total==1) +// .contract(other, 0, 0) → former .dot(other) (for rank-1 tensors) +// ═════════════════════════════════════════════════════════════════════════════ -pub fn Scalar(comptime T: type, comptime d: Dimensions.ArgOpts, comptime s: Scales.ArgOpts) type { - return Quantity(T, 1, d, s); -} +// ─── Scalar tests ───────────────────────────────────────────────────────── test "Scalar initiat" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }); - const Second = Scalar(f32, .{ .T = 1 }, .{ .T = .n }); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); const distance = Meter.splat(10); - const time = Second.splat(2); + const time = Second.splat(2); - try std.testing.expectEqual(10, distance.value()); - try std.testing.expectEqual(2, time.value()); + try std.testing.expectEqual(10, distance.data[0]); + try std.testing.expectEqual(2, time.data[0]); } test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const m1000 = Meter.splat(1000); - const km1 = KiloMeter.splat(1); - const km2 = KiloMeter.splat(2); + const km1 = KiloMeter.splat(1); + const km2 = KiloMeter.splat(2); - // Equal / Not Equal try std.testing.expect(m1000.eq(km1)); try std.testing.expect(km1.eq(m1000)); try std.testing.expect(km2.ne(m1000)); - // Greater Than / Greater Than or Equal try std.testing.expect(km2.gt(m1000)); try 
std.testing.expect(km2.gt(km1)); try std.testing.expect(km1.gte(m1000)); try std.testing.expect(km2.gte(m1000)); - // Less Than / Less Than or Equal try std.testing.expect(m1000.lt(km2)); try std.testing.expect(km1.lt(km2)); try std.testing.expect(km1.lte(m1000)); @@ -666,71 +807,65 @@ test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { } test "Scalar Add" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const distance = Meter.splat(10); + const distance = Meter.splat(10); const distance2 = Meter.splat(20); + const added = distance.add(distance2); + try std.testing.expectEqual(30, added.data[0]); + try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); - const added = distance.add(distance2); - try std.testing.expectEqual(30, added.value()); - try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); - - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); const distance3 = KiloMeter.splat(2); - const added2 = distance.add(distance3); - try std.testing.expectEqual(2010, added2.value()); - try std.testing.expectEqual(1, @TypeOf(added2).dims.get(.L)); + const added2 = distance.add(distance3); + try std.testing.expectEqual(2010, added2.data[0]); const added3 = distance3.add(distance).to(KiloMeter); - try std.testing.expectEqual(2, added3.value()); - try std.testing.expectEqual(1, @TypeOf(added3).dims.get(.L)); + try std.testing.expectEqual(2, added3.data[0]); - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); const distance4 = KiloMeter_f.splat(2); - const added4 = distance4.add(distance).to(KiloMeter_f); - try std.testing.expectApproxEqAbs(2.01, added4.value(), 0.000001); - try std.testing.expectEqual(1, @TypeOf(added4).dims.get(.L)); + const added4 = distance4.add(distance).to(KiloMeter_f); + try std.testing.expectApproxEqAbs(2.01, 
added4.data[0], 0.000001); } test "Scalar Sub" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); - - const a = Meter.splat(500); - const b = Meter.splat(200); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const a = Meter.splat(500); + const b = Meter.splat(200); const diff = a.sub(b); - try std.testing.expectEqual(300, diff.value()); + try std.testing.expectEqual(300, diff.data[0]); const diff2 = b.sub(a); - try std.testing.expectEqual(-300, diff2.value()); + try std.testing.expectEqual(-300, diff2.data[0]); - const km_f = KiloMeter_f.splat(2.5); - const m_f = Meter.splat(500); + const km_f = KiloMeter_f.splat(2.5); + const m_f = Meter.splat(500); const diff3 = km_f.sub(m_f); - try std.testing.expectApproxEqAbs(2000, diff3.value(), 1e-4); + try std.testing.expectApproxEqAbs(2000, diff3.data[0], 1e-4); } test "Scalar MulBy" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const d = Meter.splat(3.0); - const t = Second.splat(4.0); + const d = Meter.splat(3); + const t = Second.splat(4); + const at = d.mul(t); + try std.testing.expectEqual(12, at.data[0]); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L)); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T)); - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.L)); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.T)); - - const d2 = Meter.splat(5.0); + const d2 = Meter.splat(5); const area = d.mul(d2); - try std.testing.expectEqual(15, area.value()); - try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); + try std.testing.expectEqual(15, area.data[0]); + try std.testing.expectEqual(2, 
@TypeOf(area).dims.get(.L)); } test "Scalar MulBy with scale" { - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); - const KiloGram = Scalar(f32, .{ .M = 1 }, .{ .M = .k }); + const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); const dist = KiloMeter.splat(2.0); const mass = KiloGram.splat(3.0); @@ -740,126 +875,116 @@ test "Scalar MulBy with scale" { } test "Scalar MulBy with type change" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Second = Scalar(f64, .{ .T = 1 }, .{}); - const KmSec = Scalar(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }); - const KmSec_f = Scalar(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1}); + const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); + const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); - const d = Meter.splat(3.0); - const t = Second.splat(4.0); + const d = Meter.splat(3); + const t = Second.splat(4); - const area_time = d.mul(t).to(KmSec); - const area_time_f = d.mul(t).to(KmSec_f); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectApproxEqAbs(12.0, area_time_f.value(), 0.0001); + try std.testing.expectEqual(12, d.mul(t).to(KmSec).data[0]); + try std.testing.expectApproxEqAbs(12.0, d.mul(t).to(KmSec_f).data[0], 0.0001); } test "Scalar MulBy small" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .n }); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const d = Meter.splat(3.0); - const t = Second.splat(4.0); - - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const d = Meter.splat(3); + const t = Second.splat(4); + try std.testing.expectEqual(12, d.mul(t).data[0]); } test "Scalar MulBy 
dimensionless" { - const DimLess = Scalar(i128, .{}, .{}); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - - const d = Meter.splat(7); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(7); const scaled = d.mul(DimLess.splat(3)); - try std.testing.expectEqual(21, scaled.value()); + try std.testing.expectEqual(21, scaled.data[0]); } test "Scalar Sqrt" { - const MeterSquare = Scalar(i128, .{ .L = 2 }, .{}); + const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1}); - var d = MeterSquare.splat(9); + var d = MeterSquare.splat(9); var scaled = d.sqrt(); - try std.testing.expectEqual(3, scaled.value()); + try std.testing.expectEqual(3, scaled.data[0]); try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - d = MeterSquare.splat(-5); + d = MeterSquare.splat(-5); scaled = d.sqrt(); - try std.testing.expectEqual(0, scaled.value()); + try std.testing.expectEqual(0, scaled.data[0]); - const MeterSquare_f = Scalar(f64, .{ .L = 2 }, .{}); - const d2 = MeterSquare_f.splat(20); + const d2 = MeterSquare_f.splat(20); const scaled2 = d2.sqrt(); - try std.testing.expectApproxEqAbs(4.472135955, scaled2.value(), 1e-4); + try std.testing.expectApproxEqAbs(4.472135955, scaled2.data[0], 1e-4); } test "Scalar Chained: velocity and acceleration" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const dist = Meter.splat(100.0); - const t1 = Second.splat(5.0); + const dist = Meter.splat(100); + const t1 = Second.splat(5); const velocity = dist.div(t1); - try std.testing.expectEqual(20, velocity.value()); + try std.testing.expectEqual(20, velocity.data[0]); - const t2 = Second.splat(4.0); + const t2 = Second.splat(4); const accel = velocity.div(t2); - try std.testing.expectEqual(5, 
accel.value()); + try std.testing.expectEqual(5, accel.data[0]); } test "Scalar DivBy integer exact" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); const dist = Meter.splat(120); const time = Second.splat(4); - const vel = dist.div(time); - - try std.testing.expectEqual(30, vel.value()); + const vel = dist.div(time); + try std.testing.expectEqual(30, vel.data[0]); } test "Scalar Finer scales skip dim 0" { - const Dimless = Scalar(i128, .{}, .{}); - const KiloMetre = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); + const Dimless = Tensor(i128, .{}, .{}, &.{1}); + const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const r = Dimless.splat(30); - const time = KiloMetre.splat(4); - const vel = r.mul(time); - - try std.testing.expectEqual(120, vel.value()); + const r = Dimless.splat(30); + const km = KiloMetre.splat(4); + const vel = r.mul(km); + try std.testing.expectEqual(120, vel.data[0]); try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L)); } test "Scalar Conversion chain: km -> m -> cm" { - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const CentiMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .c }); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); const km = KiloMeter.splat(15); - const m = km.to(Meter); + const m = km.to(Meter); const cm = m.to(CentiMeter); - - try std.testing.expectEqual(15_000, m.value()); - try std.testing.expectEqual(1_500_000, cm.value()); + try std.testing.expectEqual(15_000, m.data[0]); + try std.testing.expectEqual(1_500_000, cm.data[0]); } test "Scalar Conversion: hours -> minutes -> seconds" { - const Hour = Scalar(i128, .{ .T = 1 }, .{ .T = .hour 
}); - const Minute = Scalar(i128, .{ .T = 1 }, .{ .T = .min }); - const Second = Scalar(i128, .{ .T = 1 }, .{}); + const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); + const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); + const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1}); - const h = Hour.splat(1.0); + const h = Hour.splat(1); const min = h.to(Minute); const sec = min.to(Second); - - try std.testing.expectEqual(60, min.value()); - try std.testing.expectEqual(3600, sec.value()); + try std.testing.expectEqual(60, min.data[0]); + try std.testing.expectEqual(3600, sec.data[0]); } -test "Scalar Format Scalar" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const Meter = Scalar(f32, .{ .L = 1 }, .{}); +test "Scalar Format" { + const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1}); + const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - const m = Meter.splat(1.23456); + const m = Meter.splat(1.23456); const accel = MeterPerSecondSq.splat(9.81); var buf: [64]u8 = undefined; @@ -871,96 +996,72 @@ test "Scalar Format Scalar" { } test "Scalar Abs" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const m1 = Meter.splat(-50); - const m2 = m1.abs(); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - try std.testing.expectEqual(50, m2.value()); - - const m_float = Scalar(f32, .{ .L = 1 }, .{}); - const m3 = m_float.splat(-42.5); - try std.testing.expectEqual(42.5, m3.abs().value()); + try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]); + try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]); } test "Scalar Pow" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(4); - - const area = d.pow(2); - try std.testing.expectEqual(16, area.value()); - - const volume = d.pow(3); - try std.testing.expectEqual(64, volume.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d 
= Meter.splat(4); + try std.testing.expectEqual(16, d.pow(2).data[0]); + try std.testing.expectEqual(64, d.pow(3).data[0]); } test "Scalar mul comptime_int" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(7); - - const scaled = d.mul(3); - try std.testing.expectEqual(21, scaled.value()); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(7); + try std.testing.expectEqual(21, d.mul(3).data[0]); } test "Scalar add/sub bare number on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); const a = DimLess.splat(10); - - const b = a.add(5); - try std.testing.expectEqual(15, b.value()); - - const c = a.sub(3); - try std.testing.expectEqual(7, c.value()); + try std.testing.expectEqual(15, a.add(5).data[0]); + try std.testing.expectEqual(7, a.sub(3).data[0]); } test "Scalar Imperial length scales" { - const Foot = Scalar(f64, .{ .L = 1 }, .{ .L = .ft }); - const Meter = Scalar(f64, .{ .L = 1 }, .{}); - const Inch = Scalar(f64, .{ .L = 1 }, .{ .L = .inch }); + const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); + const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1}); + const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); - const one_ft = Foot.splat(1.0); - try std.testing.expectApproxEqAbs(0.3048, one_ft.to(Meter).value(), 1e-9); - - const twelve_in = Inch.splat(12.0); - try std.testing.expectApproxEqAbs(1.0, twelve_in.to(Foot).value(), 1e-9); + try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9); + try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9); } test "Scalar Imperial mass scales" { - const Pound = Scalar(f64, .{ .M = 1 }, .{ .M = .lb }); - const Ounce = Scalar(f64, .{ .M = 1 }, .{ .M = .oz }); + const Pound = Tensor(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1}); + const Ounce = Tensor(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1}); - const two_lb = Pound.splat(2.0); - const eight_oz = 
Ounce.splat(8.0); - const total = two_lb.add(eight_oz).to(Pound); - try std.testing.expectApproxEqAbs(2.5, total.value(), 1e-6); + const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound); + try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6); } test "Scalar comparisons with comptime_int on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); const x = DimLess.splat(42); - try std.testing.expect(x.eq(42)); try std.testing.expect(x.gt(10)); } -// Vector tests - -pub fn Vector(N: comptime_int, Q: type) type { - return Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()); -} +// ─── Vector / Tensor tests ──────────────────────────────────────────────── test "Vector initiate" { - const Meter = Vector(4, Scalar(f32, .{ .L = 1 }, .{})); - const m = Meter.splat(1); - + const Meter4 = Tensor(f32, .{ .L = 1 }, .{}, &.{4}); + const m = Meter4.splat(1); try std.testing.expect(m.data[0] == 1); + try std.testing.expect(m.data[3] == 1); } test "Vector format" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const KgMeterPerSecond = Scalar(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }); + const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); + const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); - const accel = MeterPerSecondSq.Vec3.splat(9.81); - const momentum = KgMeterPerSecond.Vec3{ .data = .{ 43, 0, 11 } }; + const accel = MeterPerSecondSq.splat(9.81); + const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } }; var buf: [64]u8 = undefined; var res = try std.fmt.bufPrint(&buf, "{d}", .{accel}); @@ -971,28 +1072,20 @@ test "Vector format" { } test "Vector Vec3 Init and Basic Arithmetic" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - // Test zero, one, initDefault - const v_zero = Vec3M.zero; + const v_zero = 
Meter3.zero; try std.testing.expectEqual(0, v_zero.data[0]); - try std.testing.expectEqual(0, v_zero.data[1]); try std.testing.expectEqual(0, v_zero.data[2]); - const v_one = Vec3M.one; + const v_one = Meter3.one; try std.testing.expectEqual(1, v_one.data[0]); - try std.testing.expectEqual(1, v_one.data[1]); - try std.testing.expectEqual(1, v_one.data[2]); - const v_def = Vec3M.splat(5); - try std.testing.expectEqual(5, v_def.data[0]); - try std.testing.expectEqual(5, v_def.data[1]); + const v_def = Meter3.splat(5); try std.testing.expectEqual(5, v_def.data[2]); - // Test add and sub - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 4, 6 } }; + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 4, 6 } }; const added = v1.add(v2); try std.testing.expectEqual(12, added.data[0]); @@ -1000,260 +1093,256 @@ test "Vector Vec3 Init and Basic Arithmetic" { try std.testing.expectEqual(36, added.data[2]); const subbed = v1.sub(v2); - try std.testing.expectEqual(8, subbed.data[0]); + try std.testing.expectEqual(8, subbed.data[0]); try std.testing.expectEqual(16, subbed.data[1]); try std.testing.expectEqual(24, subbed.data[2]); - // Test negate const neg = v1.negate(); try std.testing.expectEqual(-10, neg.data[0]); try std.testing.expectEqual(-20, neg.data[1]); try std.testing.expectEqual(-30, neg.data[2]); } -test "Vector Kinematics (Scalar Mul/Div)" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Second = Scalar(i32, .{ .T = 1 }, .{}); - const Vec3M = Meter.Vec3; +test "Vector Kinematics (scalar mul/div broadcast)" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1}); - const pos = Vec3M{ .data = .{ 100, 200, 300 } }; - const time = Second.splat(10); + const pos = Meter3{ .data = .{ 100, 200, 300 } }; + const time = Second1.splat(10); - // Vector divided by scalar (Velocity = Position / Time) - const vel = pos.divScalar(time); - try 
std.testing.expectEqual(10, vel.data[0]); - try std.testing.expectEqual(20, vel.data[1]); - try std.testing.expectEqual(30, vel.data[2]); - try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); - try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); + const vel = pos.div(time); + try std.testing.expectEqual(10, vel.data[0]); + try std.testing.expectEqual(20, vel.data[1]); + try std.testing.expectEqual(30, vel.data[2]); + try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); + try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); - // Vector multiplied by scalar (Position = Velocity * Time) - const new_pos = vel.mulScalar(time); + const new_pos = vel.mul(time); try std.testing.expectEqual(100, new_pos.data[0]); - try std.testing.expectEqual(200, new_pos.data[1]); - try std.testing.expectEqual(300, new_pos.data[2]); - try std.testing.expectEqual(1, @TypeOf(new_pos).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); + try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); } test "Vector Element-wise Math and Scaling" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 5, 10 } }; - - // Element-wise division - const div = v1.div(v2); - try std.testing.expectEqual(5, div.data[0]); - try std.testing.expectEqual(4, div.data[1]); - try std.testing.expectEqual(3, div.data[2]); - try std.testing.expectEqual(0, @TypeOf(div).dims.get(.L)); // M / M = Dimensionless + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 5, 10 } }; + const dv = v1.div(v2); + try std.testing.expectEqual(5, dv.data[0]); + try std.testing.expectEqual(4, dv.data[1]); + try std.testing.expectEqual(3, dv.data[2]); + try std.testing.expectEqual(0, @TypeOf(dv).dims.get(.L)); } test "Vector Conversions" { - const KiloMeter = Scalar(i32, .{ .L = 1 }, .{ .L = .k 
}); - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - - const v_km = KiloMeter.Vec3{ .data = .{ 1, 2, 3 } }; - const v_m = v_km.to(Meter); + const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } }; + const v_m = v_km.to(Meter3); try std.testing.expectEqual(1000, v_m.data[0]); try std.testing.expectEqual(2000, v_m.data[1]); try std.testing.expectEqual(3000, v_m.data[2]); - - // Type checking the result - try std.testing.expectEqual(1, @TypeOf(v_m).dims.get(.L)); try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m).scales.get(.L)); } test "Vector Length" { - const MeterInt = Scalar(i32, .{ .L = 1 }, .{}); - const MeterFloat = Scalar(f32, .{ .L = 1 }, .{}); + const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - // Integer length - // 3-4-5 triangle on XY plane - const v_int = MeterInt.Vec3{ .data = .{ 3, 4, 0 } }; + const v_int = MeterInt3{ .data = .{ 3, 4, 0 } }; try std.testing.expectEqual(25, v_int.lengthSqr()); - try std.testing.expectEqual(5, v_int.length()); + try std.testing.expectEqual(5, v_int.length()); - // Float length - const v_float = MeterFloat.Vec3{ .data = .{ 3.0, 4.0, 0.0 } }; + const v_float = MeterFloat3{ .data = .{ 3.0, 4.0, 0.0 } }; try std.testing.expectApproxEqAbs(@as(f32, 25.0), v_float.lengthSqr(), 1e-4); - try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); + try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); } test "Vector Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); - const v1 = Meter.Vec3{ .data = .{ 1000.0, 500.0, 0.0 } }; - const v2 = KiloMeter.Vec3{ .data = .{ 1.0, 0.5, 0.0 } }; - const v3 = KiloMeter.Vec3{ .data = 
.{ 1.0, 0.6, 0.0 } }; + const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; + const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; + const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } }; - // 1. Equality (Whole vector) try std.testing.expect(v1.eqAll(v2)); try std.testing.expect(v1.neAll(v3)); - // 2. Element-wise Ordered Comparison - const higher = v3.gt(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(false, higher[0]); // 1km == 1000m - try std.testing.expectEqual(true, higher[1]); // 0.6km > 500m - try std.testing.expectEqual(false, higher[2]); // 0 == 0 + const higher = v3.gt(v1); + try std.testing.expectEqual(false, higher[0]); + try std.testing.expectEqual(true, higher[1]); + try std.testing.expectEqual(false, higher[2]); - // 3. Element-wise Equal Comparison - const equal = v3.eq(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(true, equal[0]); // 1km == 1000m - try std.testing.expectEqual(false, equal[1]); // 0.6km > 500m - try std.testing.expectEqual(true, equal[2]); // 0 == 0 + const equal = v3.eq(v1); + try std.testing.expectEqual(true, equal[0]); + try std.testing.expectEqual(false, equal[1]); + try std.testing.expectEqual(true, equal[2]); - // 3. Less than or equal const low_eq = v1.lte(v3); try std.testing.expect(low_eq[0] and low_eq[1] and low_eq[2]); } -test "Vector vs Scalar Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); +test "Vector vs Scalar broadcast comparison" { + // Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs. 
+ const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const positions = Meter.Vec3{ .data = .{ 500.0, 1200.0, 3000.0 } }; - const threshold = KiloMeter.splat(1); // 1km (1000m) + const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } }; + const threshold = KiloMeter1.splat(1); // 1 km = 1000 m - // Check which axes exceed the 1km threshold - const exceeded = positions.gtScalar(threshold); + const exceeded = positions.gt(threshold); + try std.testing.expectEqual(false, exceeded[0]); + try std.testing.expectEqual(true, exceeded[1]); + try std.testing.expectEqual(true, exceeded[2]); - try std.testing.expectEqual(false, exceeded[0]); // 500m > 1km is false - try std.testing.expectEqual(true, exceeded[1]); // 1200m > 1km is true - try std.testing.expectEqual(true, exceeded[2]); // 3000m > 1km is true - - // Check for equality (broadcasted) - const exact_match = positions.eqScalar(Meter.splat(500)); - try std.testing.expect(exact_match[0] == true); - try std.testing.expect(exact_match[1] == false); + const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const exact = positions.eq(Meter1.splat(500)); + try std.testing.expect(exact[0] == true); + try std.testing.expect(exact[1] == false); } -test "Vector Dot and Cross Products" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const Newton = Scalar(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}); +test "Vector contract — dot product (rank-1 × rank-1)" { + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); - const pos = Meter.Vec3{ .data = .{ 10.0, 0.0, 0.0 } }; - const force = Newton.Vec3{ .data = .{ 5.0, 5.0, 0.0 } }; + const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; + const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; - // 1. 
Dot Product (Work = F dot d) - const work = force.dot(pos); - try std.testing.expectEqual(50.0, work.value()); - // Dimensions should be M¹L²T⁻² (Energy/Joules) + // work = force · pos + const work = force.contract(pos, 0, 0); + try std.testing.expectEqual(50.0, work.data[0]); try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); try std.testing.expectEqual(2, @TypeOf(work).dims.get(.L)); try std.testing.expectEqual(-2, @TypeOf(work).dims.get(.T)); +} - // 2. Cross Product (Torque = r cross F) - const torque = pos.cross(force); - try std.testing.expectEqual(0.0, torque.data[0]); - try std.testing.expectEqual(0.0, torque.data[1]); - try std.testing.expectEqual(50.0, torque.data[2]); - // Torque dimensions are same as Energy but as a Vector - try std.testing.expectEqual(2, @TypeOf(torque).dims.get(.L)); +test "Vector contract — matrix multiply (rank-2 × rank-2)" { + // 2×3 matrix multiplied by 3×2 matrix → 2×2 result + const A = Tensor(f32, .{}, .{}, &.{2, 3}); + const B = Tensor(f32, .{}, .{}, &.{3, 2}); + + // A = [[1, 2, 3], + // [4, 5, 6]] + const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; + // B = [[7, 8], + // [9, 10], + // [11, 12]] + const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; + + // C = A @ B (contract over axis 1 of A × axis 0 of B) + // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58 + // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64 + // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139 + // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154 + const c = a.contract(b, 1, 0); + try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 0})]); + try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 1})]); + try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 0})]); + try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 1})]); } test "Vector Abs, Pow, Sqrt and Product" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); + const Meter3 = Tensor(f32, .{ .L = 1 }, 
.{}, &.{3}); - const v1 = Meter.Vec3{ .data = .{ -2.0, 3.0, -4.0 } }; - - // 1. Abs + const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; const v_abs = v1.abs(); try std.testing.expectEqual(2.0, v_abs.data[0]); try std.testing.expectEqual(4.0, v_abs.data[2]); - // 2. Product (L1 * L1 * L1 = L3) const vol = v_abs.product(); - try std.testing.expectEqual(24.0, vol.value()); + try std.testing.expectEqual(24.0, vol.data[0]); try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L)); - // 3. Pow (Scalar exponent: (L1)^2 = L2) const area_vec = v_abs.pow(2); - try std.testing.expectEqual(4.0, area_vec.data[0]); + try std.testing.expectEqual(4.0, area_vec.data[0]); try std.testing.expectEqual(16.0, area_vec.data[2]); try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L)); - // 4. Sqrt const sqrted = area_vec.sqrt(); try std.testing.expectEqual(2, sqrted.data[0]); try std.testing.expectEqual(4, sqrted.data[2]); try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L)); } -test "Vector mulScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 1, 2, 3 } }; - - const scaled = v.mulScalar(10); // comptime_int → dimensionless +test "Vector mul comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 1, 2, 3 } }; + const scaled = v.mul(10); try std.testing.expectEqual(10, scaled.data[0]); try std.testing.expectEqual(20, scaled.data[1]); try std.testing.expectEqual(30, scaled.data[2]); - // Dimensions unchanged: L¹ × dimensionless = L¹ try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(scaled).dims.get(.T)); } -test "Vector mulScalar comptime_float" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 1.0, 2.0, 4.0 } }; - - const scaled = v.mulScalar(0.5); // comptime_float → dimensionless +test "Vector mul comptime_float broadcast" { + const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, 
&.{3}); + const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; + const scaled = v.mul(0.5); try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6); try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6); try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); } -test "Vector mulScalar T (value type)" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 3.0, 6.0, 9.0 } }; - const factor: f32 = 2.0; - - const scaled = v.mulScalar(factor); - try std.testing.expectApproxEqAbs(6.0, scaled.data[0], 1e-6); - try std.testing.expectApproxEqAbs(12.0, scaled.data[1], 1e-6); - try std.testing.expectApproxEqAbs(18.0, scaled.data[2], 1e-6); - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); -} - -test "Vector divScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 10, 20, 30 } }; - - const halved = v.divScalar(2); // comptime_int → dimensionless divisor - try std.testing.expectEqual(5, halved.data[0]); +test "Vector div comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 10, 20, 30 } }; + const halved = v.div(2); + try std.testing.expectEqual(5, halved.data[0]); try std.testing.expectEqual(10, halved.data[1]); try std.testing.expectEqual(15, halved.data[2]); try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L)); } -test "Vector divScalar comptime_float" { - const MeterF = Scalar(f64, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 9.0, 6.0, 3.0 } }; - - const r = v.divScalar(3.0); +test "Vector div comptime_float broadcast" { + const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3}); + const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; + const r = v.div(3.0); try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9); try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9); - try 
std.testing.expectEqual(1, @TypeOf(r).dims.get(.L)); } -test "Vector eqScalar / gtScalar with comptime_int on dimensionless vector" { - // Bare numbers are dimensionless, so comparisons only work when vector is dimensionless too. - const DimLess = Scalar(i32, .{}, .{}); - const v = DimLess.Vec3{ .data = .{ 1, 2, 3 } }; +test "Vector eq broadcast on dimensionless" { + const DimLess3 = Tensor(i32, .{}, .{}, &.{3}); + const v = DimLess3{ .data = .{ 1, 2, 3 } }; - const eq_res = v.eqScalar(2); + const eq_res = v.eq(2); try std.testing.expectEqual(false, eq_res[0]); - try std.testing.expectEqual(true, eq_res[1]); + try std.testing.expectEqual(true, eq_res[1]); try std.testing.expectEqual(false, eq_res[2]); - const gt_res = v.gtScalar(1); + const gt_res = v.gt(1); try std.testing.expectEqual(false, gt_res[0]); - try std.testing.expectEqual(true, gt_res[1]); - try std.testing.expectEqual(true, gt_res[2]); + try std.testing.expectEqual(true, gt_res[1]); + try std.testing.expectEqual(true, gt_res[2]); +} + +test "Tensor idx helper and matrix access" { + const Mat3x3 = Tensor(f32, .{}, .{}, &.{3, 3}); + // Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3 + var m: Mat3x3 = Mat3x3.zero; + m.data[Mat3x3.idx(.{0, 0})] = 1.0; + m.data[Mat3x3.idx(.{1, 1})] = 2.0; + m.data[Mat3x3.idx(.{2, 2})] = 3.0; + + try std.testing.expectEqual(1.0, m.data[0]); // [0][0] + try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4) + try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8) + try std.testing.expectEqual(0.0, m.data[1]); // [0][1] +} + +test "Tensor strides_arr correctness" { + const T1 = Tensor(f32, .{}, .{}, &.{3}); + const T2 = Tensor(f32, .{}, .{}, &.{3, 4}); + const T3 = Tensor(f32, .{}, .{}, &.{2, 3, 4}); + + try std.testing.expectEqual(1, T1.strides_arr[0]); + try std.testing.expectEqual(4, T2.strides_arr[0]); + try std.testing.expectEqual(1, T2.strides_arr[1]); + try std.testing.expectEqual(12, T3.strides_arr[0]); + try std.testing.expectEqual(4, 
T3.strides_arr[1]); + try std.testing.expectEqual(1, T3.strides_arr[2]); } From f37a196b1540bb0c5ad8ed3dc880460b931f7e0c Mon Sep 17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 09:11:24 +0200 Subject: [PATCH 02/10] Fixed new Tensor to be everything (Scalar, Vector, Matrix and above) --- src/Base.zig | 59 ++-- src/Scales.zig | 3 +- src/{Quantity.zig => Tensor.zig} | 585 +++++++++++++++++-------------- src/helper.zig | 97 ----- src/main.zig | 6 +- 5 files changed, 354 insertions(+), 396 deletions(-) rename src/{Quantity.zig => Tensor.zig} (73%) delete mode 100644 src/helper.zig diff --git a/src/Base.zig b/src/Base.zig index 4bde68b..4a475a7 100644 --- a/src/Base.zig +++ b/src/Base.zig @@ -3,34 +3,39 @@ const std = @import("std"); // Adjust these imports to match your actual file names const Dimensions = @import("Dimensions.zig"); const Scales = @import("Scales.zig"); -const Scalar = @import("Quantity.zig").Scalar; +const Tensor = @import("Tensor.zig").Tensor; fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type { return struct { - const dims = Dimensions.init(d); - const scales = Scales.init(s); + pub const dims = Dimensions.init(d); + pub const scales = Scales.init(s); /// Instantiates the constant into a specific numeric type. - pub fn Of(comptime T: type) Scalar(T, d, s) { - return .{ .data = @splat(@as(T, @floatCast(val))) }; + pub fn Of(comptime T: type) Tensor(T, d, s, &.{1}) { + const casted_val: T = switch (@typeInfo(T)) { + .float => @floatCast(val), + .int => @intFromFloat(val), + else => @compileError("Unsupported type for PhysicalConstant"), + }; + return Tensor(T, d, s, &.{1}).splat(casted_val); } }; } fn BaseScalar(comptime d: Dimensions.ArgOpts) type { return struct { - const dims = Dimensions.init(d); + pub const dims = Dimensions.init(d); /// Creates a Scalar of this dimension using default scales. 
- /// Example: const V = Quantities.Velocity.Base(f32); + /// Example: const V = Quantities.Velocity.Of(f32); pub fn Of(comptime T: type) type { - return Scalar(T, d, .{}); + return Tensor(T, d, .{}, &.{1}); } /// Creates a Scalar of this dimension using custom scales. - /// Example: const Kmh = Quantities.Velocity.Scaled(f32, Scales.init(.{ .L = .k, .T = .hour })); + /// Example: const Kmh = Quantities.Velocity.Scaled(f32, .{ .L = .k, .T = .hour }); pub fn Scaled(comptime T: type, comptime s: Scales.ArgOpts) type { - return Scalar(T, d, s); + return Tensor(T, d, s, &.{1}); } }; } @@ -107,7 +112,7 @@ pub const ElectricCapacitance = BaseScalar(.{ .T = 4, .L = -2, .M = -1, .I = 2 } pub const ElectricImpedance = ElectricResistance; pub const MagneticFlux = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .I = -1 }); pub const MagneticDensity = BaseScalar(.{ .M = 1, .T = -2, .I = -1 }); -pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); // Fixed typo from MagneticStrengh +pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); pub const MagneticMoment = BaseScalar(.{ .L = 2, .I = 1 }); // ========================================== @@ -140,7 +145,7 @@ pub const ThermalHeat = Energy; pub const ThermalWork = Energy; pub const ThermalCapacity = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 }); pub const ThermalCapacityPerMass = BaseScalar(.{ .L = 2, .T = -2, .Tr = -1 }); -pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); // Fixed typo from ThermalluxDensity +pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); pub const ThermalConductance = BaseScalar(.{ .M = 1, .L = 2, .T = -3, .Tr = -1 }); pub const ThermalConductivity = BaseScalar(.{ .M = 1, .L = 1, .T = -3, .Tr = -1 }); pub const ThermalResistance = BaseScalar(.{ .M = -1, .L = -2, .T = 3, .Tr = 1 }); @@ -152,20 +157,24 @@ pub const ThermalEntropy = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 }); // ========================================== pub const Frequency = BaseScalar(.{ 
.T = -1 }); pub const Viscosity = BaseScalar(.{ .M = 1, .L = -1, .T = -1 }); -pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); // Corrected from MT-2a +pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); + +// ========================================== +// Tests +// ========================================== test "BaseQuantities - Core dimensions instantiation" { // Basic types via generic wrappers const M = Meter.Of(f32); const distance = M.splat(100); - try std.testing.expectEqual(100.0, distance.value()); + try std.testing.expectEqual(100.0, distance.data[0]); try std.testing.expectEqual(1, M.dims.get(.L)); try std.testing.expectEqual(0, M.dims.get(.T)); // Test specific scale variants const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour }); const speed = Kmh.splat(120); - try std.testing.expectEqual(120.0, speed.value()); + try std.testing.expectEqual(120.0, speed.data[0]); try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L)); try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T)); } @@ -176,12 +185,12 @@ test "BaseQuantities - Kinematics equations" { // Velocity = Distance / Time const v = d.div(t); - try std.testing.expectEqual(25.0, v.value()); + try std.testing.expectEqual(25.0, v.data[0]); try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); // Acceleration = Velocity / Time const a = v.div(t); - try std.testing.expectEqual(12.5, a.value()); + try std.testing.expectEqual(12.5, a.data[0]); try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); } @@ -193,13 +202,13 @@ test "BaseQuantities - Dynamics (Force and Work)" { // Force = mass * acceleration const f = m.mul(a); - try std.testing.expectEqual(98, f.value()); + try std.testing.expectEqual(98, f.data[0]); try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); // Energy (Work) = Force * distance const distance = Meter.Of(f32).splat(5.0); const energy = f.mul(distance); - try std.testing.expectEqual(490, energy.value()); + try 
std.testing.expectEqual(490, energy.data[0]); try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); } @@ -209,26 +218,26 @@ test "BaseQuantities - Electric combinations" { // Charge = Current * time const charge = current.mul(time); - try std.testing.expectEqual(6.0, charge.value()); + try std.testing.expectEqual(6.0, charge.data[0]); try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); } test "Constants - Initialization and dimension checks" { // Speed of Light const c = Constants.SpeedOfLight.Of(f64); - try std.testing.expectEqual(299792458.0, c.value()); + try std.testing.expectEqual(299792458.0, c.data[0]); try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L)); try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T)); // Electron Mass (verifying scale as well) const me = Constants.ElectronMass.Of(f64); - try std.testing.expectEqual(9.1093837139e-31, me.value()); + try std.testing.expectEqual(9.1093837139e-31, me.data[0]); try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M)); try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg // Boltzmann Constant (Complex derived dimensions) const kb = Constants.Boltzmann.Of(f64); - try std.testing.expectEqual(1.380649e-23, kb.value()); + try std.testing.expectEqual(1.380649e-23, kb.data[0]); try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M)); try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L)); try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T)); @@ -237,7 +246,7 @@ test "Constants - Initialization and dimension checks" { // Vacuum Permittivity const eps0 = Constants.VacuumPermittivity.Of(f64); - try std.testing.expectEqual(8.8541878188e-12, eps0.value()); + try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]); try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M)); try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L)); try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T)); @@ -245,7 +254,7 @@ test "Constants - 
Initialization and dimension checks" { // Fine Structure Constant (Dimensionless) const alpha = Constants.FineStructure.Of(f64); - try std.testing.expectEqual(0.0072973525643, alpha.value()); + try std.testing.expectEqual(0.0072973525643, alpha.data[0]); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M)); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L)); } diff --git a/src/Scales.zig b/src/Scales.zig index 246565f..06031fb 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const hlp = @import("helper.zig"); const Dimensions = @import("Dimensions.zig"); const Dimension = @import("Dimensions.zig").Dimension; @@ -99,7 +98,7 @@ data: std.EnumArray(Dimension, UnitScale), pub fn init(comptime init_val: ArgOpts) Self { comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) }; inline for (std.meta.fields(@TypeOf(init_val))) |f| { - if (comptime hlp.isInt(@TypeOf(@field(init_val, f.name)))) + if (comptime @typeInfo(@TypeOf(@field(init_val, f.name))) == .comptime_int) s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name))) else s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); diff --git a/src/Quantity.zig b/src/Tensor.zig similarity index 73% rename from src/Quantity.zig rename to src/Tensor.zig index f0edea6..7cb8a3c 100644 --- a/src/Quantity.zig +++ b/src/Tensor.zig @@ -1,12 +1,11 @@ const std = @import("std"); -const hlp = @import("helper.zig"); const Scales = @import("Scales.zig"); const UnitScale = Scales.UnitScale; const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; // ───────────────────────────────────────────────────────────────────────────── -// Comptime shape utilities +// Comptime utilities // ───────────────────────────────────────────────────────────────────────────── pub fn shapeTotal(comptime shape: []const usize) usize { @@ -34,7 +33,10 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, 
comptime axis: usize) [sha var out: [shape.len - 1]usize = undefined; var j: usize = 0; for (shape, 0..) |v, i| { - if (i != axis) { out[j] = v; j += 1; } + if (i != axis) { + out[j] = v; + j += 1; + } } return out; } @@ -96,6 +98,32 @@ pub fn insertAxis( return out; } +fn isInt(comptime T: type) bool { + return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; +} + +fn finerScales(comptime T1: type, comptime T2: type) Scales { + const d1: Dimensions = T1.dims; + const d2: Dimensions = T2.dims; + const s1: Scales = T1.scales; + const s2: Scales = T2.scales; + comptime var out = Scales.initFill(.none); + inline for (std.enums.values(Dimension)) |dim| { + const scale1 = comptime s1.get(dim); + const scale2 = comptime s2.get(dim); + out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) + .none + else if (comptime d1.get(dim) == 0) + scale2 + else if (comptime d2.get(dim) == 0) + scale1 + else if (comptime scale1.getFactor() > scale2.getFactor()) + scale2 + else + scale1); + } + comptime return out; +} // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers // @@ -104,76 +132,93 @@ pub fn insertAxis( // Actual Tensor types are passed through unchanged. 
// ───────────────────────────────────────────────────────────────────────────── +fn isTensor(comptime Rhs: type) bool { + return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); +} + fn RhsTensorType(comptime T: type, comptime Rhs: type) type { - if (@hasDecl(Rhs, "ISTENSOR")) return Rhs; + if (comptime isTensor(Rhs)) return Rhs; return Tensor(T, .{}, .{}, &.{1}); } fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { const Rhs = @TypeOf(r); - if (comptime @hasDecl(Rhs, "ISTENSOR")) return r; + if (comptime isTensor(Rhs)) return r; const scalar: T = switch (comptime @typeInfo(Rhs)) { .comptime_int => switch (comptime @typeInfo(T)) { .float => @as(T, @floatFromInt(r)), - else => @as(T, r), + else => @as(T, r), }, .comptime_float => switch (comptime @typeInfo(T)) { - .int => @as(T, @intFromFloat(r)), - else => @as(T, r), + .int => @as(T, @intFromFloat(r)), + else => @as(T, r), }, - .int => switch (comptime @typeInfo(T)) { + .int => switch (comptime @typeInfo(T)) { .float => @floatFromInt(r), - else => @intCast(r), + else => @intCast(r), }, .float => switch (comptime @typeInfo(T)) { - .int => @intFromFloat(r), - else => @floatCast(r), + .int => @intFromFloat(r), + else => @floatCast(r), }, else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)), }; return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; } -// ───────────────────────────────────────────────────────────────────────────── -// Tensor — unified dimensioned ND type. -// -// T : element numeric type (f32, f64, i32, i128, …) -// d_opt : SI dimension exponents -// s_opt : unit scales -// shape_ : compile-time shape -// &.{1} → scalar -// &.{3} → 3-vector -// &.{4, 4} → 4×4 matrix -// &.{3, 3, 3} → 3D field -// -// Storage: flat @Vector(total, T) where total = product(shape_). -// All arithmetic operates on the flat vector directly → SIMD wherever possible. 
-// -// Shape-related comptime constants exposed on every Tensor type: -// dims : Dimensions — SI exponent struct -// scales : Scales — unit scale struct -// shape : []const usize -// rank : usize = shape.len -// total : usize = product(shape) -// strides_arr : [rank]usize — row-major strides -// -// Index helper: -// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) -// -// GPU readiness: -// tensor.asSlice() → []T (zero-copy pointer to the flat @Vector storage) -// -// Contraction (replaces dot / cross / matmul): -// a.contract(b, axis_a, axis_b) -// For rank-1 × rank-1 this is the dot product. -// For rank-2 × rank-2 with axis_a=1, axis_b=0 this is matrix multiply. -// -// Removed from Quantity: -// Scalar / Vector aliases, Vec3 / ScalarType, .value(), .vec(), .vec3(), -// dot(), cross(), mulScalar(), divScalar(), eqScalar() and friends. -// Use Tensor(..., &.{1}), .data[0], mul(), div(), eq() respectively. -// ───────────────────────────────────────────────────────────────────────────── +pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { + if (n == 0) return; + var val = n; + if (val < 0) { + try writer.writeAll("\u{207B}"); + val = -val; + } + var buf: [12]u8 = undefined; + const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return; + for (str) |c| { + const s = switch (c) { + '0' => "\u{2070}", + '1' => "\u{00B9}", + '2' => "\u{00B2}", + '3' => "\u{00B3}", + '4' => "\u{2074}", + '5' => "\u{2075}", + '6' => "\u{2076}", + '7' => "\u{2077}", + '8' => "\u{2078}", + '9' => "\u{2079}", + else => unreachable, + }; + try writer.writeAll(s); + } +} +/// ───────────────────────────────────────────────────────────────────────────── +/// Tensor — unified dimensioned ND type. 
+/// +/// T : element numeric type (f32, f64, i32, i128, …) +/// d_opt : SI dimension exponents +/// s_opt : unit scales +/// shape_ : compile-time shape +/// &.{1} → scalar +/// &.{3} → 3-vector +/// &.{4, 4} → 4×4 matrix +/// &.{3, 3, 3} → 3D field +/// +/// Storage: flat @Vector(total, T) where total = product(shape_). +/// All arithmetic operates on the flat vector directly → SIMD wherever possible. +/// +/// Shape-related comptime constants exposed on every Tensor type: +/// dims : Dimensions — SI exponent struct +/// scales : Scales — unit scale struct +/// shape : []const usize +/// rank : usize = shape.len +/// total : usize = product(shape) +/// strides_arr : [rank]usize — row-major strides +/// +/// Index helper: +/// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) +/// ───────────────────────────────────────────────────────────────────────────── pub fn Tensor( comptime T: type, comptime d_opt: Dimensions.ArgOpts, @@ -186,9 +231,9 @@ pub fn Tensor( } @setEvalBranchQuota(10_000_000); - const _total: usize = comptime shapeTotal(shape_); - const _strides = comptime shapeStrides(shape_); - const Vec = @Vector(_total, T); + const _total: usize = comptime shapeTotal(shape_); + const _strides = comptime shapeStrides(shape_); + const Vec = @Vector(_total, T); return struct { /// Flat SIMD storage. All arithmetic operates here directly. 
@@ -196,14 +241,14 @@ pub fn Tensor( const Self = @This(); - pub const ValueType : type = T; - pub const dims : Dimensions = Dimensions.init(d_opt); - pub const scales : Scales = Scales.init(s_opt); - pub const shape : []const usize = shape_; - pub const rank : usize = shape_.len; - pub const total : usize = _total; + pub const ValueType: type = T; + pub const dims: Dimensions = Dimensions.init(d_opt); + pub const scales: Scales = Scales.init(s_opt); + pub const shape: []const usize = shape_; + pub const rank: usize = shape_.len; + pub const total: usize = _total; pub const strides_arr: [shape_.len]usize = _strides; - pub const ISTENSOR = true; + pub const ISTENSOR = true; // ─────────────────────────────────────────────────────────────── // Index helper @@ -211,7 +256,7 @@ pub fn Tensor( /// Convert N-D coords (row-major) to flat index — fully comptime. /// Usage: Tensor.idx(.{row, col}) - pub fn idx(comptime coords: [rank]usize) usize { + pub inline fn idx(comptime coords: [rank]usize) usize { comptime { var flat: usize = 0; for (0..rank) |i| { @@ -232,7 +277,7 @@ pub fn Tensor( } pub const zero: Self = splat(0); - pub const one: Self = splat(1); + pub const one: Self = splat(1); // ─────────────────────────────────────────────────────────────── // GPU readiness @@ -247,8 +292,12 @@ pub fn Tensor( // Internal: RHS normalisation // ─────────────────────────────────────────────────────────────── - inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } - inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { return toRhsTensor(T, r); } + inline fn RhsT(comptime Rhs: type) type { + return RhsTensorType(T, Rhs); + } + inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { + return toRhsTensor(T, r); + } // ─────────────────────────────────────────────────────────────── // Internal: scalar broadcast (shape {1} → full Vec) @@ -270,48 +319,48 @@ pub fn Tensor( pub inline fn add(self: Self, r: anytype) Tensor( T, dims.argsOpt(), - hlp.finerScales(Self, 
RhsT(@TypeOf(r))).argsOpt(), + finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != total and RhsType.total != 1) @compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); - const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr }; + return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; } /// Element-wise subtract. Dimensions must match; scales resolve to finer. 
pub inline fn sub(self: Self, r: anytype) Tensor( T, dims.argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != total and RhsType.total != 1) @compileError("Shape mismatch in sub."); - const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr }; + return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; } /// Element-wise multiply. Dimension exponents summed. 
@@ -319,20 +368,20 @@ pub fn Tensor( pub inline fn mul(self: Self, r: anytype) Tensor( T, dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime RhsType.total != total and RhsType.total != 1) @compileError("Shape mismatch in mul."); - const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); - return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr }; + const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); + return .{ .data = if (comptime isInt(T)) l *| rr else l * rr }; } /// Element-wise divide. Dimension exponents subtracted. 
@@ -340,20 +389,20 @@ pub fn Tensor( pub inline fn div(self: Self, r: anytype) Tensor( T, dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); if (comptime RhsType.total != total and RhsType.total != 1) @compileError("Shape mismatch in div."); - const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); - if (comptime hlp.isInt(T)) { + const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; + const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rr: Vec = broadcastToVec(RhsNorm, rr_base); + if (comptime isInt(T)) { var result: Vec = undefined; inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]); return .{ .data = result }; @@ -378,7 +427,7 @@ pub fn Tensor( scales.argsOpt(), shape_, ) { - if (comptime hlp.isInt(T)) { + if (comptime isInt(T)) { var result: Vec = undefined; inline for (0..total) |i| result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); @@ -440,8 +489,8 @@ pub fn Tensor( comptime std.debug.assert(Dest.total == total or Dest.total == 1); - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestT = ActualDest.ValueType; + const ratio = 
comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); const DestVec = @Vector(total, DestT); // ── Same numeric type ────────────────────────────────────── @@ -455,7 +504,7 @@ pub fn Tensor( return .{ .data = self.data *| @as(Vec, @splat(mult)) }; } else { const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); - const half: T = comptime @divTrunc(div_val, 2); + const half: T = comptime @divTrunc(div_val, 2); var result: DestVec = undefined; inline for (0..total) |i| { const val = self.data[i]; @@ -473,14 +522,14 @@ pub fn Tensor( inline for (0..total) |i| { const float_val: f64 = switch (comptime @typeInfo(T)) { .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, + .int => @floatFromInt(self.data[i]), + else => unreachable, }; const scaled = float_val * ratio; result[i] = switch (comptime @typeInfo(DestT)) { .float => @floatCast(scaled), - .int => @intFromFloat(@round(scaled)), - else => unreachable, + .int => @intFromFloat(@round(scaled)), + else => unreachable, }; } return .{ .data = result }; @@ -505,11 +554,11 @@ pub fn Tensor( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. 
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { - const RhsType = @TypeOf(rhs_q); - const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsType = @TypeOf(rhs_q); + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; @@ -604,39 +653,39 @@ pub fn Tensor( comptime axis_b: usize, ) blk: { const OT = @TypeOf(other); - comptime std.debug.assert(axis_a < rank); - comptime std.debug.assert(axis_b < OT.rank); - comptime std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); + std.debug.assert(axis_a < rank); + std.debug.assert(axis_b < OT.rank); + std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); // Contracted-away free axes; empty joint → scalar shape {1} - const sa = shapeRemoveAxis(shape_, axis_a); - const sb = shapeRemoveAxis(OT.shape, axis_b); + const sa = shapeRemoveAxis(shape_, axis_a); + const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; break :blk Tensor( T, dims.add(OT.dims).argsOpt(), - hlp.finerScales(Self, OT).argsOpt(), + finerScales(Self, OT).argsOpt(), rs, ); } { const OT = @TypeOf(other); const k: usize = comptime shape_[axis_a]; // contraction dimension - const sa = comptime shapeRemoveAxis(shape_, axis_a); - const sb = comptime shapeRemoveAxis(OT.shape, axis_b); + const sa = comptime shapeRemoveAxis(shape_, axis_a); + const sb = comptime shapeRemoveAxis(OT.shape, axis_b); const rs_raw = comptime 
shapeCat(&sa, &sb); const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const ResultType = Tensor( T, dims.add(OT.dims).argsOpt(), - hlp.finerScales(Self, OT).argsOpt(), + finerScales(Self, OT).argsOpt(), rs, ); // Normalise scales before accumulation - const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), shape_); - const OtherNorm = Tensor(T, OT.dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), OT.shape); + const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); + const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; @@ -658,10 +707,10 @@ pub fn Tensor( // Reinsert the contracted index into free coords → full coord arrays const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); - const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); - const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); - if (comptime hlp.isInt(T)) + if (comptime isInt(T)) acc +|= a_data[a_flat] *| b_data[b_flat] else acc += a_data[a_flat] * b_data[b_flat]; @@ -712,10 +761,10 @@ pub fn Tensor( if (comptime total == 1) { switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[0], options), - .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, @@ 
-726,10 +775,10 @@ pub fn Tensor( if (i > 0) try writer.writeAll(", "); switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[i], options), - .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ - .width = options.width, + .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ + .width = options.width, .alignment = options.alignment, - .fill = options.fill, + .fill = options.fill, .precision = options.precision, }), else => unreachable, @@ -751,7 +800,7 @@ pub fn Tensor( else try writer.print("{s}{s}", .{ uscale.str(), bu.unit() }); - if (v != 1) try hlp.printSuperscript(writer, v); + if (v != 1) try printSuperscript(writer, v); } } }; @@ -773,23 +822,23 @@ pub fn Tensor( // ─── Scalar tests ───────────────────────────────────────────────────────── test "Scalar initiat" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); const distance = Meter.splat(10); - const time = Second.splat(2); + const time = Second.splat(2); try std.testing.expectEqual(10, distance.data[0]); - try std.testing.expectEqual(2, time.data[0]); + try std.testing.expectEqual(2, time.data[0]); } test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const m1000 = Meter.splat(1000); - const km1 = KiloMeter.splat(1); - const km2 = KiloMeter.splat(2); + const km1 = KiloMeter.splat(1); + const km2 = KiloMeter.splat(2); try std.testing.expect(m1000.eq(km1)); try std.testing.expect(km1.eq(m1000)); @@ -807,65 +856,65 @@ test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { } test "Scalar Add" { - const Meter = Tensor(i128, .{ .L = 1 }, 
.{}, &.{1}); - const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const distance = Meter.splat(10); + const distance = Meter.splat(10); const distance2 = Meter.splat(20); - const added = distance.add(distance2); + const added = distance.add(distance2); try std.testing.expectEqual(30, added.data[0]); - try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); + try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); const distance3 = KiloMeter.splat(2); - const added2 = distance.add(distance3); + const added2 = distance.add(distance3); try std.testing.expectEqual(2010, added2.data[0]); const added3 = distance3.add(distance).to(KiloMeter); try std.testing.expectEqual(2, added3.data[0]); const distance4 = KiloMeter_f.splat(2); - const added4 = distance4.add(distance).to(KiloMeter_f); + const added4 = distance4.add(distance).to(KiloMeter_f); try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001); } test "Scalar Sub" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const a = Meter.splat(500); - const b = Meter.splat(200); + const a = Meter.splat(500); + const b = Meter.splat(200); const diff = a.sub(b); - try std.testing.expectEqual(300, diff.data[0]); + try std.testing.expectEqual(300, diff.data[0]); const diff2 = b.sub(a); try std.testing.expectEqual(-300, diff2.data[0]); - const km_f = KiloMeter_f.splat(2.5); - const m_f = Meter.splat(500); + const km_f = KiloMeter_f.splat(2.5); + const m_f = Meter.splat(500); const diff3 = km_f.sub(m_f); try std.testing.expectApproxEqAbs(2000, diff3.data[0], 1e-4); } test "Scalar MulBy" { - const Meter = 
Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const d = Meter.splat(3); - const t = Second.splat(4); + const d = Meter.splat(3); + const t = Second.splat(4); const at = d.mul(t); try std.testing.expectEqual(12, at.data[0]); - try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L)); - try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T)); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L)); + try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T)); - const d2 = Meter.splat(5); + const d2 = Meter.splat(5); const area = d.mul(d2); try std.testing.expectEqual(15, area.data[0]); - try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); + try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); } test "Scalar MulBy with scale" { const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); + const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); const dist = KiloMeter.splat(2.0); const mass = KiloGram.splat(3.0); @@ -875,9 +924,9 @@ test "Scalar MulBy with scale" { } test "Scalar MulBy with type change" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1}); - const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1}); + const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); const d = Meter.splat(3); @@ -888,103 +937,103 @@ test "Scalar MulBy with type change" { } test "Scalar MulBy small" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 
}, .{ .L = .n }, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); const d = Meter.splat(3); const t = Second.splat(4); try std.testing.expectEqual(12, d.mul(t).data[0]); } test "Scalar MulBy dimensionless" { - const DimLess = Tensor(i128, .{}, .{}, &.{1}); - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const d = Meter.splat(7); + const DimLess = Tensor(i128, .{}, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const d = Meter.splat(7); const scaled = d.mul(DimLess.splat(3)); try std.testing.expectEqual(21, scaled.data[0]); } test "Scalar Sqrt" { - const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1}); - const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1}); - var d = MeterSquare.splat(9); + var d = MeterSquare.splat(9); var scaled = d.sqrt(); try std.testing.expectEqual(3, scaled.data[0]); try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - d = MeterSquare.splat(-5); + d = MeterSquare.splat(-5); scaled = d.sqrt(); try std.testing.expectEqual(0, scaled.data[0]); - const d2 = MeterSquare_f.splat(20); + const d2 = MeterSquare_f.splat(20); const scaled2 = d2.sqrt(); try std.testing.expectApproxEqAbs(4.472135955, scaled2.data[0], 1e-4); } test "Scalar Chained: velocity and acceleration" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); - const dist = Meter.splat(100); - const t1 = Second.splat(5); + const dist = Meter.splat(100); + const t1 = Second.splat(5); const velocity = dist.div(t1); try std.testing.expectEqual(20, velocity.data[0]); - const t2 = Second.splat(4); + const t2 = Second.splat(4); const accel = velocity.div(t2); try std.testing.expectEqual(5, accel.data[0]); } test "Scalar DivBy integer exact" { - 
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); const dist = Meter.splat(120); const time = Second.splat(4); - const vel = dist.div(time); + const vel = dist.div(time); try std.testing.expectEqual(30, vel.data[0]); } test "Scalar Finer scales skip dim 0" { - const Dimless = Tensor(i128, .{}, .{}, &.{1}); - const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Dimless = Tensor(i128, .{}, .{}, &.{1}); + const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const r = Dimless.splat(30); - const km = KiloMetre.splat(4); - const vel = r.mul(km); + const r = Dimless.splat(30); + const km = KiloMetre.splat(4); + const vel = r.mul(km); try std.testing.expectEqual(120, vel.data[0]); try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L)); } test "Scalar Conversion chain: km -> m -> cm" { - const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); + const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); const km = KiloMeter.splat(15); - const m = km.to(Meter); + const m = km.to(Meter); const cm = m.to(CentiMeter); - try std.testing.expectEqual(15_000, m.data[0]); + try std.testing.expectEqual(15_000, m.data[0]); try std.testing.expectEqual(1_500_000, cm.data[0]); } test "Scalar Conversion: hours -> minutes -> seconds" { - const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); - const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); - const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1}); + const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); + const Minute = 
Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); + const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1}); - const h = Hour.splat(1); + const h = Hour.splat(1); const min = h.to(Minute); const sec = min.to(Second); - try std.testing.expectEqual(60, min.data[0]); + try std.testing.expectEqual(60, min.data[0]); try std.testing.expectEqual(3600, sec.data[0]); } test "Scalar Format" { const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1}); - const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - const m = Meter.splat(1.23456); + const m = Meter.splat(1.23456); const accel = MeterPerSecondSq.splat(9.81); var buf: [64]u8 = undefined; @@ -996,23 +1045,23 @@ test "Scalar Format" { } test "Scalar Abs" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]); + try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]); try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]); } test "Scalar Pow" { const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const d = Meter.splat(4); + const d = Meter.splat(4); try std.testing.expectEqual(16, d.pow(2).data[0]); try std.testing.expectEqual(64, d.pow(3).data[0]); } test "Scalar mul comptime_int" { const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const d = Meter.splat(7); + const d = Meter.splat(7); try std.testing.expectEqual(21, d.mul(3).data[0]); } @@ -1020,16 +1069,16 @@ test "Scalar add/sub bare number on dimensionless scalar" { const DimLess = Tensor(i128, .{}, .{}, &.{1}); const a = DimLess.splat(10); try std.testing.expectEqual(15, a.add(5).data[0]); - try std.testing.expectEqual(7, a.sub(3).data[0]); + try std.testing.expectEqual(7, a.sub(3).data[0]); } test "Scalar Imperial length scales" { - const 
Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); - const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1}); - const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); + const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); + const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1}); + const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9); - try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9); + try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9); } test "Scalar Imperial mass scales" { @@ -1057,10 +1106,10 @@ test "Vector initiate" { } test "Vector format" { - const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); - const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); + const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); + const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); - const accel = MeterPerSecondSq.splat(9.81); + const accel = MeterPerSecondSq.splat(9.81); const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } }; var buf: [64]u8 = undefined; @@ -1085,7 +1134,7 @@ test "Vector Vec3 Init and Basic Arithmetic" { try std.testing.expectEqual(5, v_def.data[2]); const v1 = Meter3{ .data = .{ 10, 20, 30 } }; - const v2 = Meter3{ .data = .{ 2, 4, 6 } }; + const v2 = Meter3{ .data = .{ 2, 4, 6 } }; const added = v1.add(v2); try std.testing.expectEqual(12, added.data[0]); @@ -1093,7 +1142,7 @@ test "Vector Vec3 Init and Basic Arithmetic" { try std.testing.expectEqual(36, added.data[2]); const subbed = v1.sub(v2); - try std.testing.expectEqual(8, subbed.data[0]); + try std.testing.expectEqual(8, subbed.data[0]); try std.testing.expectEqual(16, subbed.data[1]); try std.testing.expectEqual(24, subbed.data[2]); @@ -1104,30 +1153,30 @@ test "Vector Vec3 Init and Basic Arithmetic" { } test 
"Vector Kinematics (scalar mul/div broadcast)" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1}); - const pos = Meter3{ .data = .{ 100, 200, 300 } }; + const pos = Meter3{ .data = .{ 100, 200, 300 } }; const time = Second1.splat(10); const vel = pos.div(time); - try std.testing.expectEqual(10, vel.data[0]); - try std.testing.expectEqual(20, vel.data[1]); - try std.testing.expectEqual(30, vel.data[2]); - try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); - try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); + try std.testing.expectEqual(10, vel.data[0]); + try std.testing.expectEqual(20, vel.data[1]); + try std.testing.expectEqual(30, vel.data[2]); + try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); + try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); const new_pos = vel.mul(time); try std.testing.expectEqual(100, new_pos.data[0]); - try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); + try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); } test "Vector Element-wise Math and Scaling" { const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const v1 = Meter3{ .data = .{ 10, 20, 30 } }; - const v2 = Meter3{ .data = .{ 2, 5, 10 } }; - const dv = v1.div(v2); + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 5, 10 } }; + const dv = v1.div(v2); try std.testing.expectEqual(5, dv.data[0]); try std.testing.expectEqual(4, dv.data[1]); try std.testing.expectEqual(3, dv.data[2]); @@ -1136,10 +1185,10 @@ test "Vector Element-wise Math and Scaling" { test "Vector Conversions" { const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } }; - const v_m = v_km.to(Meter3); + const v_m = v_km.to(Meter3); try 
std.testing.expectEqual(1000, v_m.data[0]); try std.testing.expectEqual(2000, v_m.data[1]); try std.testing.expectEqual(3000, v_m.data[2]); @@ -1147,38 +1196,38 @@ test "Vector Conversions" { } test "Vector Length" { - const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); const v_int = MeterInt3{ .data = .{ 3, 4, 0 } }; try std.testing.expectEqual(25, v_int.lengthSqr()); - try std.testing.expectEqual(5, v_int.length()); + try std.testing.expectEqual(5, v_int.length()); const v_float = MeterFloat3{ .data = .{ 3.0, 4.0, 0.0 } }; try std.testing.expectApproxEqAbs(@as(f32, 25.0), v_float.lengthSqr(), 1e-4); - try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); + try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); } test "Vector Comparisons" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); - const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; - const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; - const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } }; + const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; + const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; + const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } }; try std.testing.expect(v1.eqAll(v2)); try std.testing.expect(v1.neAll(v3)); const higher = v3.gt(v1); try std.testing.expectEqual(false, higher[0]); - try std.testing.expectEqual(true, higher[1]); + try std.testing.expectEqual(true, higher[1]); try std.testing.expectEqual(false, higher[2]); const equal = v3.eq(v1); - try std.testing.expectEqual(true, equal[0]); + try std.testing.expectEqual(true, equal[0]); try std.testing.expectEqual(false, equal[1]); - try std.testing.expectEqual(true, equal[2]); + try 
std.testing.expectEqual(true, equal[2]); const low_eq = v1.lte(v3); try std.testing.expect(low_eq[0] and low_eq[1] and low_eq[2]); @@ -1186,7 +1235,7 @@ test "Vector Comparisons" { test "Vector vs Scalar broadcast comparison" { // Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs. - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } }; @@ -1194,21 +1243,21 @@ test "Vector vs Scalar broadcast comparison" { const exceeded = positions.gt(threshold); try std.testing.expectEqual(false, exceeded[0]); - try std.testing.expectEqual(true, exceeded[1]); - try std.testing.expectEqual(true, exceeded[2]); + try std.testing.expectEqual(true, exceeded[1]); + try std.testing.expectEqual(true, exceeded[2]); - const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); - const exact = positions.eq(Meter1.splat(500)); + const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const exact = positions.eq(Meter1.splat(500)); try std.testing.expect(exact[0] == true); try std.testing.expect(exact[1] == false); } test "Vector contract — dot product (rank-1 × rank-1)" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); - const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; - const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; + const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; + const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; // work = force · pos const work = force.contract(pos, 0, 0); @@ -1220,8 +1269,8 @@ test "Vector contract — dot product (rank-1 × rank-1)" { test "Vector contract — matrix multiply (rank-2 × rank-2)" { // 2×3 matrix multiplied by 3×2 matrix → 2×2 result - const A 
= Tensor(f32, .{}, .{}, &.{2, 3}); - const B = Tensor(f32, .{}, .{}, &.{3, 2}); + const A = Tensor(f32, .{}, .{}, &.{ 2, 3 }); + const B = Tensor(f32, .{}, .{}, &.{ 3, 2 }); // A = [[1, 2, 3], // [4, 5, 6]] @@ -1237,16 +1286,16 @@ test "Vector contract — matrix multiply (rank-2 × rank-2)" { // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139 // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154 const c = a.contract(b, 1, 0); - try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 0})]); - try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 1})]); - try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 0})]); - try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 1})]); + try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); + try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); + try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]); + try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]); } test "Vector Abs, Pow, Sqrt and Product" { const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; + const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; const v_abs = v1.abs(); try std.testing.expectEqual(2.0, v_abs.data[0]); try std.testing.expectEqual(4.0, v_abs.data[2]); @@ -1256,7 +1305,7 @@ test "Vector Abs, Pow, Sqrt and Product" { try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L)); const area_vec = v_abs.pow(2); - try std.testing.expectEqual(4.0, area_vec.data[0]); + try std.testing.expectEqual(4.0, area_vec.data[0]); try std.testing.expectEqual(16.0, area_vec.data[2]); try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L)); @@ -1268,7 +1317,7 @@ test "Vector Abs, Pow, Sqrt and Product" { test "Vector mul comptime_int broadcast" { const Meter3 = Tensor(i32, .{ 
.L = 1 }, .{}, &.{3}); - const v = Meter3{ .data = .{ 1, 2, 3 } }; + const v = Meter3{ .data = .{ 1, 2, 3 } }; const scaled = v.mul(10); try std.testing.expectEqual(10, scaled.data[0]); try std.testing.expectEqual(20, scaled.data[1]); @@ -1278,8 +1327,8 @@ test "Vector mul comptime_int broadcast" { test "Vector mul comptime_float broadcast" { const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; - const scaled = v.mul(0.5); + const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; + const scaled = v.mul(0.5); try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6); try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6); @@ -1288,9 +1337,9 @@ test "Vector mul comptime_float broadcast" { test "Vector div comptime_int broadcast" { const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const v = Meter3{ .data = .{ 10, 20, 30 } }; + const v = Meter3{ .data = .{ 10, 20, 30 } }; const halved = v.div(2); - try std.testing.expectEqual(5, halved.data[0]); + try std.testing.expectEqual(5, halved.data[0]); try std.testing.expectEqual(10, halved.data[1]); try std.testing.expectEqual(15, halved.data[2]); try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L)); @@ -1298,8 +1347,8 @@ test "Vector div comptime_int broadcast" { test "Vector div comptime_float broadcast" { const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3}); - const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; - const r = v.div(3.0); + const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; + const r = v.div(3.0); try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9); try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9); @@ -1311,22 +1360,22 @@ test "Vector eq broadcast on dimensionless" { const eq_res = v.eq(2); try std.testing.expectEqual(false, eq_res[0]); - try std.testing.expectEqual(true, eq_res[1]); + try std.testing.expectEqual(true, 
eq_res[1]); try std.testing.expectEqual(false, eq_res[2]); const gt_res = v.gt(1); try std.testing.expectEqual(false, gt_res[0]); - try std.testing.expectEqual(true, gt_res[1]); - try std.testing.expectEqual(true, gt_res[2]); + try std.testing.expectEqual(true, gt_res[1]); + try std.testing.expectEqual(true, gt_res[2]); } test "Tensor idx helper and matrix access" { - const Mat3x3 = Tensor(f32, .{}, .{}, &.{3, 3}); + const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 }); // Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3 var m: Mat3x3 = Mat3x3.zero; - m.data[Mat3x3.idx(.{0, 0})] = 1.0; - m.data[Mat3x3.idx(.{1, 1})] = 2.0; - m.data[Mat3x3.idx(.{2, 2})] = 3.0; + m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0; + m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0; + m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0; try std.testing.expectEqual(1.0, m.data[0]); // [0][0] try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4) @@ -1336,13 +1385,13 @@ test "Tensor idx helper and matrix access" { test "Tensor strides_arr correctness" { const T1 = Tensor(f32, .{}, .{}, &.{3}); - const T2 = Tensor(f32, .{}, .{}, &.{3, 4}); - const T3 = Tensor(f32, .{}, .{}, &.{2, 3, 4}); + const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 }); + const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 }); - try std.testing.expectEqual(1, T1.strides_arr[0]); - try std.testing.expectEqual(4, T2.strides_arr[0]); - try std.testing.expectEqual(1, T2.strides_arr[1]); + try std.testing.expectEqual(1, T1.strides_arr[0]); + try std.testing.expectEqual(4, T2.strides_arr[0]); + try std.testing.expectEqual(1, T2.strides_arr[1]); try std.testing.expectEqual(12, T3.strides_arr[0]); - try std.testing.expectEqual(4, T3.strides_arr[1]); - try std.testing.expectEqual(1, T3.strides_arr[2]); + try std.testing.expectEqual(4, T3.strides_arr[1]); + try std.testing.expectEqual(1, T3.strides_arr[2]); } diff --git a/src/helper.zig b/src/helper.zig deleted file mode 100644 index fb104e1..0000000 --- a/src/helper.zig +++ /dev/null @@ -1,97 +0,0 @@ -const std = 
@import("std"); - -pub fn isInt(comptime T: type) bool { - return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; -} - -pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { - if (n == 0) return; - var val = n; - if (val < 0) { - try writer.writeAll("\u{207B}"); - val = -val; - } - var buf: [12]u8 = undefined; - const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return; - for (str) |c| { - const s = switch (c) { - '0' => "\u{2070}", - '1' => "\u{00B9}", - '2' => "\u{00B2}", - '3' => "\u{00B3}", - '4' => "\u{2074}", - '5' => "\u{2075}", - '6' => "\u{2076}", - '7' => "\u{2077}", - '8' => "\u{2078}", - '9' => "\u{2079}", - else => unreachable, - }; - try writer.writeAll(s); - } -} - -const Scales = @import("Scales.zig"); -const Dimensions = @import("Dimensions.zig"); -const Dimension = @import("Dimensions.zig").Dimension; - -pub fn finerScales(comptime T1: type, comptime T2: type) Scales { - const d1: Dimensions = T1.dims; - const d2: Dimensions = T2.dims; - const s1: Scales = T1.scales; - const s2: Scales = T2.scales; - comptime var out = Scales.initFill(.none); - inline for (std.enums.values(Dimension)) |dim| { - const scale1 = comptime s1.get(dim); - const scale2 = comptime s2.get(dim); - out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) - .none - else if (comptime d1.get(dim) == 0) - scale2 - else if (comptime d2.get(dim) == 0) - scale1 - else if (comptime scale1.getFactor() > scale2.getFactor()) - scale2 - else - scale1); - } - comptime return out; -} - -// --------------------------------------------------------------------------- -// RHS normalisation helpers -// --------------------------------------------------------------------------- - -const Quantity = @import("Quantity.zig").Quantity; - -/// Returns true if `T` is a `Scalar_` type (has `dims`, `scales`, and `value`). 
-pub fn isScalarType(comptime T: type) bool { - return @typeInfo(T) == .@"struct" and - @hasDecl(T, "ISQUANTITY") and - @field(T, "ISQUANTITY"); -} - -/// Resolve the Scalar type that `rhs` will be treated as. -/// -/// Accepted rhs types: -/// - Any `Scalar_` type → returned as-is -/// - `comptime_int` / `comptime_float` → dimensionless `Scalar_(BaseT, {}, {})` -/// - `BaseT` (the scalar's value type) → dimensionless `Scalar_(BaseT, {}, {})` -/// -/// Everything else is a compile error, including other int/float types. -pub fn rhsQuantityType(comptime ValueType: type, N: usize, comptime RhsT: type) type { - if (comptime isScalarType(RhsT)) return RhsT; - if (comptime RhsT == comptime_int or RhsT == comptime_float or RhsT == ValueType) - return Quantity(ValueType, N, .{}, .{}); - @compileError( - "rhs must be a Scalar, " ++ @typeName(ValueType) ++ - ", comptime_int, or comptime_float; got " ++ @typeName(RhsT), - ); -} - -/// Convert `rhs` to its normalised Scalar form (see `rhsScalarType`). 
-pub inline fn toRhsQuantity(comptime BaseT: type, N: usize, rhs: anytype) rhsQuantityType(BaseT, N, @TypeOf(rhs)) { - if (comptime isScalarType(@TypeOf(rhs))) return rhs; - const DimLess = Quantity(BaseT, N, .{}, .{}); - return DimLess{ .data = @splat(@as(BaseT, rhs)) }; -} diff --git a/src/main.zig b/src/main.zig index 4d61ee6..33ba626 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,15 +1,13 @@ const std = @import("std"); -pub const Vector = @import("Quantity.zig").Vector; -pub const Scalar = @import("Quantity.zig").Scalar; +pub const Tensor = @import("Tensor.zig").Tensor; pub const Dimensions = @import("Dimensions.zig"); pub const Scales = @import("Scales.zig"); pub const Base = @import("Base.zig"); test { - _ = @import("Quantity.zig"); + _ = @import("Tensor.zig"); _ = @import("Dimensions.zig"); _ = @import("Scales.zig"); _ = @import("Base.zig"); - _ = @import("helper.zig"); } From 16d25e7e7eba2a009bb5806d7a3ccffae4f3b9ad Mon Sep 17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 14:45:41 +0200 Subject: [PATCH 03/10] Added shape comptime check for Tensor add/sub/div/mul --- src/Tensor.zig | 197 ++++++++++--------------------------------------- 1 file changed, 40 insertions(+), 157 deletions(-) diff --git a/src/Tensor.zig b/src/Tensor.zig index 7cb8a3c..6cc93ab 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -14,6 +14,15 @@ pub fn shapeTotal(comptime shape: []const usize) usize { return t; } +/// Check if two shapes are strictly identical. +pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { + if (a.len != b.len) return false; + for (a, 0..) |v, i| { + if (v != b[i]) return false; + } + return true; +} + /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. 
shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} @@ -126,10 +135,6 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { } // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers -// -// Any bare comptime_int / comptime_float / runtime T used as an arithmetic -// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}. -// Actual Tensor types are passed through unchanged. // ───────────────────────────────────────────────────────────────────────────── fn isTensor(comptime Rhs: type) bool { @@ -193,32 +198,6 @@ pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { } } -/// ───────────────────────────────────────────────────────────────────────────── -/// Tensor — unified dimensioned ND type. -/// -/// T : element numeric type (f32, f64, i32, i128, …) -/// d_opt : SI dimension exponents -/// s_opt : unit scales -/// shape_ : compile-time shape -/// &.{1} → scalar -/// &.{3} → 3-vector -/// &.{4, 4} → 4×4 matrix -/// &.{3, 3, 3} → 3D field -/// -/// Storage: flat @Vector(total, T) where total = product(shape_). -/// All arithmetic operates on the flat vector directly → SIMD wherever possible. 
-/// -/// Shape-related comptime constants exposed on every Tensor type: -/// dims : Dimensions — SI exponent struct -/// scales : Scales — unit scale struct -/// shape : []const usize -/// rank : usize = shape.len -/// total : usize = product(shape) -/// strides_arr : [rank]usize — row-major strides -/// -/// Index helper: -/// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost) -/// ───────────────────────────────────────────────────────────────────────────── pub fn Tensor( comptime T: type, comptime d_opt: Dimensions.ArgOpts, @@ -226,8 +205,10 @@ pub fn Tensor( comptime shape_: []const usize, ) type { comptime { - std.debug.assert(shape_.len >= 1); - for (shape_) |s| std.debug.assert(s >= 1); + if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); + for (shape_) |s| { + if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); + } } @setEvalBranchQuota(10_000_000); @@ -236,7 +217,6 @@ pub fn Tensor( const Vec = @Vector(_total, T); return struct { - /// Flat SIMD storage. All arithmetic operates here directly. data: Vec, const Self = @This(); @@ -250,27 +230,19 @@ pub fn Tensor( pub const strides_arr: [shape_.len]usize = _strides; pub const ISTENSOR = true; - // ─────────────────────────────────────────────────────────────── - // Index helper - // ─────────────────────────────────────────────────────────────── - /// Convert N-D coords (row-major) to flat index — fully comptime. 
/// Usage: Tensor.idx(.{row, col}) pub inline fn idx(comptime coords: [rank]usize) usize { comptime { var flat: usize = 0; for (0..rank) |i| { - std.debug.assert(coords[i] < shape[i]); + if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds"); flat += coords[i] * strides_arr[i]; } return flat; } } - // ─────────────────────────────────────────────────────────────── - // Constructors - // ─────────────────────────────────────────────────────────────── - /// Broadcast a single value across all elements. pub inline fn splat(v: T) Self { return .{ .data = @splat(v) }; @@ -279,19 +251,11 @@ pub fn Tensor( pub const zero: Self = splat(0); pub const one: Self = splat(1); - // ─────────────────────────────────────────────────────────────── - // GPU readiness - // ─────────────────────────────────────────────────────────────── - /// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping. pub inline fn asSlice(self: *Self) []T { return @as([*]T, @ptrCast(&self.data))[0..total]; } - // ─────────────────────────────────────────────────────────────── - // Internal: RHS normalisation - // ─────────────────────────────────────────────────────────────── - inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } @@ -299,10 +263,6 @@ pub fn Tensor( return toRhsTensor(T, r); } - // ─────────────────────────────────────────────────────────────── - // Internal: scalar broadcast (shape {1} → full Vec) - // ─────────────────────────────────────────────────────────────── - inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { return if (comptime RhsType.total == 1 and total > 1) @splat(r.data[0]) @@ -310,12 +270,8 @@ pub fn Tensor( r.data; } - // ─────────────────────────────────────────────────────────────── - // Arithmetic - // ─────────────────────────────────────────────────────────────── - /// Element-wise add. Dimensions must match; scales resolve to finer. 
- /// RHS must have the same element count as self, or total == 1 (broadcast). + /// RHS must have the same shape as self, or total == 1 (broadcast). pub inline fn add(self: Self, r: anytype) Tensor( T, dims.argsOpt(), @@ -326,8 +282,8 @@ pub fn Tensor( const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -339,7 +295,8 @@ pub fn Tensor( return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; } - /// Element-wise subtract. Dimensions must match; scales resolve to finer. + /// Element-wise sub. Dimensions must match; scales resolve to finer. + /// RHS must have the same shape as self, or total == 1 (broadcast). 
pub inline fn sub(self: Self, r: anytype) Tensor( T, dims.argsOpt(), @@ -350,8 +307,8 @@ pub fn Tensor( const RhsType = @TypeOf(rhs_q); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in sub."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -373,8 +330,8 @@ pub fn Tensor( ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in mul."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); @@ -394,8 +351,8 @@ pub fn Tensor( ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != total and RhsType.total != 1) - @compileError("Shape mismatch in div."); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); @@ -411,10 +368,6 @@ pub fn Tensor( } } - // ─────────────────────────────────────────────────────────────── - // 
Unary - // ─────────────────────────────────────────────────────────────── - /// Absolute value of every element. pub inline fn abs(self: Self) Self { return .{ .data = @bitCast(@abs(self.data)) }; @@ -469,13 +422,9 @@ pub fn Tensor( return .{ .data = -self.data }; } - // ─────────────────────────────────────────────────────────────── - // Conversion - // ─────────────────────────────────────────────────────────────── - /// Convert to a compatible Tensor type. /// • Dimension mismatch → compile error. - /// • Dest.total must equal self.total, or Dest.total == 1 (scalar pattern). + /// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern). /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. pub inline fn to( self: Self, @@ -485,20 +434,19 @@ pub fn Tensor( if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - if (comptime Self == ActualDest) return self; + if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) + @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); - comptime std.debug.assert(Dest.total == total or Dest.total == 1); + if (comptime Self == ActualDest) return self; const DestT = ActualDest.ValueType; const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); const DestVec = @Vector(total, DestT); - // ── Same numeric type ────────────────────────────────────── if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; - // Integer — branch prevents division-by-zero if (comptime ratio >= 1.0) { const mult: T = comptime @intFromFloat(@round(ratio)); return .{ .data = self.data *| @as(Vec, @splat(mult)) }; @@ -517,7 +465,6 @@ pub fn Tensor( } } - // ── Cross numeric type ───────────────────────────────────── var result: DestVec = undefined; inline 
for (0..total) |i| { const float_val: f64 = switch (comptime @typeInfo(T)) { @@ -535,17 +482,6 @@ pub fn Tensor( return .{ .data = result }; } - // ─────────────────────────────────────────────────────────────── - // Comparisons - // - // Return type: bool when total == 1 (scalar semantics) - // [total]bool when total > 1 (element-wise, flat-indexed) - // - // Whole-tensor equality check → eqAll / neAll (always returns bool). - // A shape {1} RHS is broadcast automatically, unifying the old - // eqScalar / gtScalar / … family into the plain eq / gt / … methods. - // ─────────────────────────────────────────────────────────────── - const CmpResult = if (total == 1) bool else [total]bool; inline fn cmpResult(v: @Vector(total, bool)) CmpResult { @@ -555,6 +491,9 @@ pub fn Tensor( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { const RhsType = @TypeOf(rhs_q); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { @@ -626,26 +565,6 @@ pub fn Tensor( return !self.eqAll(other); } - // ─────────────────────────────────────────────────────────────── - // Contraction — generalised dot product / matrix multiply / einsum - // - // a.contract(b, axis_a, axis_b) - // - // Sums over dimension `axis_a` of `a` and `axis_b` of `b`. - // Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime). 
- // - // Result shape = a.shape \ axis_a ++ b.shape \ axis_b - // Result dims = a.dims + b.dims (exponents summed, as in mul) - // Result scales = finer of a, b - // - // Special cases: - // rank-1 × rank-1, axis 0 × 0 → dot product (result shape {1}) - // rank-2 × rank-2, axis 1 × 0 → matrix multiply - // rank-1 × rank-2, axis 0 × 0 → vector–matrix product - // - // All index arithmetic is comptime; runtime cost is the multiply-add loop only. - // ─────────────────────────────────────────────────────────────── - pub inline fn contract( self: Self, other: anytype, @@ -653,10 +572,10 @@ pub fn Tensor( comptime axis_b: usize, ) blk: { const OT = @TypeOf(other); - std.debug.assert(axis_a < rank); - std.debug.assert(axis_b < OT.rank); - std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); - // Contracted-away free axes; empty joint → scalar shape {1} + if (axis_a >= rank) @compileError("contract: axis_a out of bounds"); + if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds"); + if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes"); + const sa = shapeRemoveAxis(shape_, axis_a); const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); @@ -683,20 +602,16 @@ pub fn Tensor( rs, ); - // Normalise scales before accumulation const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; - // Precompute result strides from rs_raw (for coord decoding) const rs_raw_strides = comptime shapeStrides(&rs_raw); var result: ResultType = .{ .data = @splat(0) }; inline for (0..ResultType.total) |res_flat| { - // Decode result flat index into free coords using rs_raw layout. 
- // When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} — correct. const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; @@ -704,7 +619,6 @@ pub fn Tensor( var acc: T = 0; inline for (0..k) |ki| { - // Reinsert the contracted index into free coords → full coord arrays const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); @@ -720,10 +634,6 @@ pub fn Tensor( return result; } - // ─────────────────────────────────────────────────────────────── - // Reduction helpers - // ─────────────────────────────────────────────────────────────── - /// Sum of squared elements. Cheaper than length(); use for ordering. pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); @@ -749,10 +659,6 @@ pub fn Tensor( return .{ .data = .{@reduce(.Mul, self.data)} }; } - // ─────────────────────────────────────────────────────────────── - // Formatting - // ─────────────────────────────────────────────────────────────── - pub fn formatNumber( self: Self, writer: *std.Io.Writer, @@ -809,15 +715,6 @@ pub fn Tensor( // ═════════════════════════════════════════════════════════════════════════════ // Tests // ───────────────────────────────────────────────────────────────────────────── -// Naming convention used throughout: -// Tensor(T, d, s, &.{1}) → former Scalar -// Tensor(T, d, s, &.{N}) → former Vector of length N -// .data[0] → former .value() -// .mul(x) → former .mulScalar(x) (x may be scalar Tensor or bare number) -// .div(x) → former .divScalar(x) -// .eq(x) → former .eqScalar(x) (broadcasts when x.total==1) -// .contract(other, 0, 0) → former .dot(other) (for rank-1 tensors) -// ═════════════════════════════════════════════════════════════════════════════ // ─── Scalar tests 
───────────────────────────────────────────────────────── @@ -1234,7 +1131,6 @@ test "Vector Comparisons" { } test "Vector vs Scalar broadcast comparison" { - // Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs. const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); @@ -1259,7 +1155,6 @@ test "Vector contract — dot product (rank-1 × rank-1)" { const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; - // work = force · pos const work = force.contract(pos, 0, 0); try std.testing.expectEqual(50.0, work.data[0]); try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); @@ -1268,23 +1163,12 @@ test "Vector contract — dot product (rank-1 × rank-1)" { } test "Vector contract — matrix multiply (rank-2 × rank-2)" { - // 2×3 matrix multiplied by 3×2 matrix → 2×2 result const A = Tensor(f32, .{}, .{}, &.{ 2, 3 }); const B = Tensor(f32, .{}, .{}, &.{ 3, 2 }); - // A = [[1, 2, 3], - // [4, 5, 6]] const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; - // B = [[7, 8], - // [9, 10], - // [11, 12]] const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; - // C = A @ B (contract over axis 1 of A × axis 0 of B) - // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58 - // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64 - // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139 - // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154 const c = a.contract(b, 1, 0); try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); @@ -1371,16 +1255,15 @@ test "Vector eq broadcast on dimensionless" { test "Tensor idx helper and matrix access" { const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 }); - // Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3 var m: Mat3x3 = Mat3x3.zero; m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0; m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0; 
m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0; - try std.testing.expectEqual(1.0, m.data[0]); // [0][0] - try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4) - try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8) - try std.testing.expectEqual(0.0, m.data[1]); // [0][1] + try std.testing.expectEqual(1.0, m.data[0]); + try std.testing.expectEqual(2.0, m.data[4]); + try std.testing.expectEqual(3.0, m.data[8]); + try std.testing.expectEqual(0.0, m.data[1]); } test "Tensor strides_arr correctness" { From cd954b379b7eedd623d4928b929626427ed121f6 Mon Sep 17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 15:13:15 +0200 Subject: [PATCH 04/10] Added cross to Tensor + fix benchmark --- src/Tensor.zig | 61 +++++++++++++++++++++++++++++++++++++++++------ src/benchmark.zig | 56 +++++++++++++++++++++---------------------- 2 files changed, 81 insertions(+), 36 deletions(-) diff --git a/src/Tensor.zig b/src/Tensor.zig index 6cc93ab..ff97dc0 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -17,9 +17,8 @@ pub fn shapeTotal(comptime shape: []const usize) usize { /// Check if two shapes are strictly identical. pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { if (a.len != b.len) return false; - for (a, 0..) |v, i| { + for (a, 0..) 
|v, i| if (v != b[i]) return false; - } return true; } @@ -432,6 +431,24 @@ pub fn Tensor( ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); + if (comptime Self == ActualDest) return self; + + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestT = ActualDest.ValueType; + const DestVec = @Vector(total, DestT); + + // If ratio is 1, just handle type conversion + if (comptime ratio == 1.0) { + if (comptime T == DestT) return .{ .data = self.data }; + return .{ + .data = switch (comptime @typeInfo(DestT)) { + .float => @floatFromInt(self.data), // or @floatCast + .int => @intFromFloat(self.data), // or @intCast + else => unreachable, + }, + }; + } + if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) @@ -439,10 +456,6 @@ pub fn Tensor( if (comptime Self == ActualDest) return self; - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestVec = @Vector(total, DestT); - if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; @@ -485,7 +498,7 @@ pub fn Tensor( const CmpResult = if (total == 1) bool else [total]bool; inline fn cmpResult(v: @Vector(total, bool)) CmpResult { - return if (comptime total == 1) v[0] else @as([total]bool, v); + return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v); } /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. @@ -634,6 +647,40 @@ pub fn Tensor( return result; } + /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. + /// Result dimensions are the sum of input dimensions. 
+ pub inline fn cross(self: Self, other: anytype) Tensor( + T, + dims.add(RhsT(@TypeOf(other)).dims).argsOpt(), + finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), + &.{3}, + ) { + const rhs_q = rhs(other); + const RhsType = @TypeOf(rhs_q); + + if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) { + @compileError("cross product is only defined for 3D vectors (rank-1, length 3)"); + } + + // Bring both to the same scale (e.g., mm vs m) + const p = self.resolveScalePair(rhs_q); + const l = p.l; + const r = p.r; + + var res: [3]T = undefined; + if (comptime isInt(T)) { + res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]); + res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]); + res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]); + } else { + res[0] = (l[1] * r[2]) - (l[2] * r[1]); + res[1] = (l[2] * r[0]) - (l[0] * r[2]); + res[2] = (l[0] * r[1]) - (l[1] * r[0]); + } + + return .{ .data = res }; + } + /// Sum of squared elements. Cheaper than length(); use for ordering. pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); diff --git a/src/benchmark.zig b/src/benchmark.zig index 7da3e6f..c50b3a9 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -1,7 +1,6 @@ const std = @import("std"); const Io = std.Io; -const Scalar = @import("Quantity.zig").Scalar; -const Vector = @import("Quantity.zig").Vector; +const Tensor = @import("Tensor.zig").Tensor; var io: Io = undefined; pub fn main(init: std.process.Init) !void { @@ -11,23 +10,23 @@ pub fn main(init: std.process.Init) !void { io = init.io; - // try vectorSIMDvsNative(f64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(f32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i128, &stdout_writer.interface); - // try 
stdout_writer.flush(); + try vectorSIMDvsNative(f64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(f32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i128, &stdout_writer.interface); + try stdout_writer.flush(); try bench_Scalar(&stdout_writer.interface); try stdout_writer.flush(); - try bench_vsNative(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_crossTypeVsNative(&stdout_writer.interface); - try stdout_writer.flush(); + // try bench_vsNative(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_crossTypeVsNative(&stdout_writer.interface); + // try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); } @@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { comptime var tidx: usize = 0; inline for (Types, TNames) |T, tname| { - const M = Scalar(T, .{ .L = 1 }, .{}); - const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k }); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); + const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); inline for (Ops, 0..) 
|op_name, oidx| { var samples: [SAMPLES]f64 = undefined; @@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M = Scalar(T, .{ .L = 1 }, .{}); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); + const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { @@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M1 = Scalar(T1, .{ .L = 1 }, .{}); - const M2 = Scalar(T2, .{ .L = 1 }, .{}); - const S2 = Scalar(T2, .{ .T = 1 }, .{}); + const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1}); + const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1}); + const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { @@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void { try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname }); inline for (Lengths) |len| { - const Q_base = Scalar(T, .{ .L = 1 }, .{}); - const Q_time = Scalar(T, .{ .T = 1 }, .{}); - const V = Vector(len, Q_base); + const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1}); + const V = Tensor(T, .{ .L = 1 }, .{}, &.{len}); // cross product is only defined for len == 3 const is_cross = comptime std.mem.eql(u8, op_name, "cross"); @@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void { _ = v1.div(V.splat(getVal(T, i +% 2, 63))); } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) { const s_val = Q_time.splat(getVal(T, i +% 2, 63)); - _ = v1.mulScalar(s_val); + _ = v1.mul(s_val); } else if (comptime std.mem.eql(u8, op_name, "dot")) { const v2 = V.splat(getVal(T, i +% 5, 63)); - _ = v1.dot(v2); + _ = v1.contract(v2, 0, 0); } else if (comptime std.mem.eql(u8, op_name, "cross")) { // len == 3 guaranteed by the guard above const v2 = V.splat(getVal(T, i +% 5, 63)); From 44aaa8a8b2648b37f2baf5e9e9f683cecb3c8988 Mon Sep 
17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 16:11:46 +0200 Subject: [PATCH 05/10] Removed all inline for (0..total) for either builtin or for loop without inline This is to prevent giant binary for Tensor with a lot of Scalar --- src/Dimensions.zig | 24 ++-- src/Scales.zig | 12 +- src/Tensor.zig | 269 ++++++++++++++++++++++++++++----------------- src/benchmark.zig | 8 +- 4 files changed, 190 insertions(+), 123 deletions(-) diff --git a/src/Dimensions.zig b/src/Dimensions.zig index e91e78d..ef37fe1 100644 --- a/src/Dimensions.zig +++ b/src/Dimensions.zig @@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int), /// Unspecified dimensions default to 0. pub fn init(comptime init_val: ArgOpts) Self { var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; - inline for (std.meta.fields(@TypeOf(init_val))) |f| + for (std.meta.fields(@TypeOf(init_val))) |f| s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); return s; } pub fn initFill(comptime val: comptime_int) Self { - return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; + comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; } pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { - return self.data.get(key); + comptime return self.data.get(key); } pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { @@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); - return args; + comptime return args; } /// Add exponents component-wise. Used internally by `mul`. 
pub fn add(comptime a: Self, comptime b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) + b.get(d)); - return result; + comptime return result; } /// Subtract exponents component-wise. Used internally by `div`. pub fn sub(comptime a: Self, comptime b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) - b.get(d)); - return result; + comptime return result; } /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) * exp); - return result; + comptime return result; } pub fn div(comptime a: Self, comptime exp: comptime_int) Self { var result = Self.initFill(0); inline for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) / exp); - return result; + comptime return result; } /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. 
diff --git a/src/Scales.zig b/src/Scales.zig index 06031fb..ea6ff08 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) { } pub inline fn getFactor(self: @This()) comptime_float { - return comptime switch (self) { + comptime return switch (self) { // Standard SI Exponents inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), @@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) { inline .lb => 453.59237, inline .st => 6350.29318, - inline else => @floatFromInt(@intFromEnum(self)), + else => @floatFromInt(@intFromEnum(self)), }; } }; @@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self { else s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); } - return s; + return comptime s; } pub fn initFill(comptime val: UnitScale) Self { - return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; + comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; } pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { @@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { } pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { - comptime self.data.set(key, val); + self.data.set(key, val); } pub fn argsOpt(self: Self) ArgOpts { @@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float factor /= base; } } - return comptime factor; + comptime return factor; } diff --git a/src/Tensor.zig b/src/Tensor.zig index ff97dc0..1cc3a91 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension; // Comptime utilities // ───────────────────────────────────────────────────────────────────────────── -pub fn shapeTotal(comptime shape: []const usize) usize { - var t: usize = 1; +pub fn shapeTotal(comptime shape: []const comptime_int) 
usize { + var t: comptime_int = 1; for (shape) |s| t *= s; return t; } /// Check if two shapes are strictly identical. -pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { +pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool { if (a.len != b.len) return false; for (a, 0..) |v, i| if (v != b[i]) return false; @@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} -pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { - var st: [shape.len]usize = undefined; +pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int { + var st: [shape.len]comptime_int = undefined; if (shape.len == 0) return st; st[shape.len - 1] = 1; if (shape.len > 1) { - var i: usize = shape.len - 1; + var i: comptime_int = shape.len - 1; while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; } return st; } /// Return a copy of `shape` with the element at `axis` removed. -pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { - var out: [shape.len - 1]usize = undefined; - var j: usize = 0; +pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int { + var out: [shape.len - 1]comptime_int = undefined; + var j: comptime_int = 0; for (shape, 0..) |v, i| { if (i != axis) { out[j] = v; @@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha } /// Concatenate two compile-time slices. 
-pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { - var out: [a.len + b.len]usize = undefined; +pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int { + var out: [a.len + b.len]comptime_int = undefined; for (a, 0..) |v, i| out[i] = v; for (b, 0..) |v, i| out[a.len + i] = v; return out; @@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b /// Decode a flat row-major index into N-D coordinates. /// Called only in comptime contexts (all arguments are comptime). pub fn decodeFlatCoords( - comptime flat: usize, - comptime n: usize, - comptime strd: [n]usize, + comptime flat: comptime_int, + comptime n: comptime_int, + comptime strd: [n]comptime_int, ) [n]usize { - var coords: [n]usize = undefined; + var coords: [n]comptime_int = undefined; var tmp = flat; for (0..n) |i| { coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; @@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { const s1: Scales = T1.scales; const s2: Scales = T2.scales; comptime var out = Scales.initFill(.none); - inline for (std.enums.values(Dimension)) |dim| { + for (std.enums.values(Dimension)) |dim| { const scale1 = comptime s1.get(dim); const scale2 = comptime s2.get(dim); out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) @@ -201,7 +201,7 @@ pub fn Tensor( comptime T: type, comptime d_opt: Dimensions.ArgOpts, comptime s_opt: Scales.ArgOpts, - comptime shape_: []const usize, + comptime shape_: []const comptime_int, ) type { comptime { if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); @@ -223,10 +223,10 @@ pub fn Tensor( pub const ValueType: type = T; pub const dims: Dimensions = Dimensions.init(d_opt); pub const scales: Scales = Scales.init(s_opt); - pub const shape: []const usize = shape_; - pub const rank: usize = shape_.len; - pub const total: usize = _total; - 
pub const strides_arr: [shape_.len]usize = _strides; + pub const shape: []const comptime_int = shape_; + pub const rank: comptime_int = shape_.len; + pub const total: comptime_int = _total; + pub const strides_arr: [shape_.len]comptime_int = _strides; pub const ISTENSOR = true; /// Convert N-D coords (row-major) to flat index — fully comptime. @@ -359,9 +359,7 @@ pub fn Tensor( const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); if (comptime isInt(T)) { - var result: Vec = undefined; - inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]); - return .{ .data = result }; + return .{ .data = @divTrunc(l, rr) }; } else { return .{ .data = l / rr }; } @@ -379,19 +377,27 @@ pub fn Tensor( scales.argsOpt(), shape_, ) { - if (comptime isInt(T)) { - var result: Vec = undefined; - inline for (0..total) |i| - result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); - return .{ .data = result }; - } else { - const abs_exp = comptime @abs(exp); - var result: Vec = @splat(1); - comptime var i = 0; - inline while (i < abs_exp) : (i += 1) result *= self.data; - if (comptime exp < 0) result = @as(Vec, @splat(1)) / result; - return .{ .data = result }; + if (comptime exp == 0) return .{ .data = @splat(1) }; + if (comptime exp == 1) return self; + + var base = self.data; + var result: Vec = @splat(1); + comptime var e = @abs(exp); + + // $O(\log n)$ Exponentiation by squaring applied to the entire vector + inline while (e > 0) { + if (e % 2 == 1) { + result = if (comptime isInt(T)) result *| base else result * base; + } + e /= 2; + if (e > 0) { + base = if (comptime isInt(T)) base *| base else base * base; + } } + if (comptime !isInt(T) and exp < 0) { + result = @as(Vec, @splat(1)) / result; + } + return .{ .data = result }; } /// Square root of every element. All dimension exponents must be even. 
@@ -404,15 +410,16 @@ pub fn Tensor( if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { - return .{ .data = @sqrt(self.data) }; + return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! } else { - var result: Vec = undefined; + const arr: [total]T = self.data; // Add this! + var res_arr: [total]T = undefined; const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - inline for (0..total) |i| { - const v = self.data[i]; - result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); + for (0..total) |i| { + const v = arr[i]; + res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } - return .{ .data = result }; + return .{ .data = res_arr }; } } @@ -433,28 +440,34 @@ pub fn Tensor( if (comptime Self == ActualDest) return self; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestT = ActualDest.ValueType; - const DestVec = @Vector(total, DestT); - - // If ratio is 1, just handle type conversion - if (comptime ratio == 1.0) { - if (comptime T == DestT) return .{ .data = self.data }; - return .{ - .data = switch (comptime @typeInfo(DestT)) { - .float => @floatFromInt(self.data), // or @floatCast - .int => @intFromFloat(self.data), // or @intCast - else => unreachable, - }, - }; - } - + // Run validation checks FIRST before dealing with types if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); - if (comptime Self == ActualDest) return self; + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestT = ActualDest.ValueType; + const 
DestVec = @Vector(total, DestT); + + // If ratio is 1, handle type conversion correctly based on BOTH source and dest types + if (comptime ratio == 1.0) { + const T_info = @typeInfo(T); + const Dest_info = @typeInfo(DestT); + + return .{ + .data = if (comptime T_info == .int and Dest_info == .int) + @as(DestVec, @intCast(self.data)) + else if (comptime T_info == .float and Dest_info == .float) + @as(DestVec, @floatCast(self.data)) + else if (comptime T_info == .int and Dest_info == .float) + @as(DestVec, @floatFromInt(self.data)) + else if (comptime T_info == .float and Dest_info == .int) + @as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding + else + unreachable, + }; + } if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) @@ -466,33 +479,33 @@ pub fn Tensor( } else { const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const half: T = comptime @divTrunc(div_val, 2); - var result: DestVec = undefined; - inline for (0..total) |i| { - const val = self.data[i]; - result[i] = if (val >= 0) - @divTrunc(val + half, div_val) - else - @divTrunc(val - half, div_val); + + if (comptime @typeInfo(T).int.signedness == .unsigned) { + return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) }; + } else { + // Vectorized branchless negative handling + const is_pos = self.data >= @as(Vec, @splat(0)); + const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half))); + return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) }; } - return .{ .data = result }; } } - var result: DestVec = undefined; - inline for (0..total) |i| { - const float_val: f64 = switch (comptime @typeInfo(T)) { - .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, - }; - const scaled = float_val * ratio; - result[i] = switch (comptime @typeInfo(DestT)) { - .float => @floatCast(scaled), - .int => 
@intFromFloat(@round(scaled)), - else => unreachable, - }; - } - return .{ .data = result }; + // Cross-type fully vectorized casting with scales + const FVec = @Vector(total, f64); + const float_vec: FVec = switch (comptime @typeInfo(T)) { + .float => @floatCast(self.data), + .int => @floatFromInt(self.data), + else => unreachable, + }; + + const scaled = float_vec * @as(FVec, @splat(ratio)); + + return switch (comptime @typeInfo(DestT)) { + .float => .{ .data = @floatCast(scaled) }, + .int => .{ .data = @intFromFloat(@round(scaled)) }, + else => unreachable, + }; } const CmpResult = if (total == 1) bool else [total]bool; @@ -592,7 +605,7 @@ pub fn Tensor( const sa = shapeRemoveAxis(shape_, axis_a); const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); - const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; + const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; break :blk Tensor( T, dims.add(OT.dims).argsOpt(), @@ -606,7 +619,7 @@ pub fn Tensor( const sa = comptime shapeRemoveAxis(shape_, axis_a); const sb = comptime shapeRemoveAxis(OT.shape, axis_b); const rs_raw = comptime shapeCat(&sa, &sb); - const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; + const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const ResultType = Tensor( T, @@ -617,34 +630,85 @@ pub fn Tensor( const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); + const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; + // FAST PATH: Dot Product + if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) { + if (comptime !isInt(T)) { + return .{ .data = @splat(@reduce(.Add, a_data * b_data)) }; + } else { + // For integers, we do a vectorized 
saturating multiply, + // then convert to an array to do a saturating sum + const mul_arr: [total]T = a_data *| b_data; + var acc: T = 0; + for (mul_arr) |val| acc +|= val; + return .{ .data = @splat(acc) }; + } + } + + // --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING --- + const a_arr: [total]T = a_data; + const b_arr: [OT.total]T = b_data; + + // FAST PATH: 2D Matrix Multiplication + if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) { + const rows = shape_[0]; + const cols = OT.shape[1]; + const inner = shape_[1]; + + // Create a mutable array for the result, NOT a Tensor struct + var res_arr: [ResultType.total]T = undefined; + + for (0..rows) |i| { + for (0..cols) |j| { + var acc: T = 0; + for (0..inner) |id| { + const a_flat = i * _strides[0] + id * _strides[1]; + const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1]; + + // Use a_arr and b_arr here + if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; + } + // Write to the array + res_arr[i * cols + j] = acc; + } + } + // Return the initialized Tensor struct + return .{ .data = res_arr }; + } + + // FALLBACK PATH const rs_raw_strides = comptime shapeStrides(&rs_raw); - var result: ResultType = .{ .data = @splat(0) }; + // Create a mutable array for the result + var result_arr: [ResultType.total]T = undefined; - inline for (0..ResultType.total) |res_flat| { - const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); + for (0..ResultType.total) |res_flat| { + const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); - const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; - const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; + var a_free: [sa.len]usize = undefined; + for (0..sa.len) |i| a_free[i] = res_coords[i]; + var b_free: [sb.len]usize = undefined; + for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i]; var acc: T = 0; - inline for (0..k) |ki| { - const 
a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); - const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); - const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); - const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + for (0..k) |ki| { + const a_coords = insertAxis(rank, axis_a, ki, &a_free); + const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free); + const a_flat = encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); - if (comptime isInt(T)) - acc +|= a_data[a_flat] *| b_data[b_flat] - else - acc += a_data[a_flat] * b_data[b_flat]; + // Use a_arr and b_arr here + if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; } - result.data[res_flat] = acc; + // Write to the array + result_arr[res_flat] = acc; } - return result; + + // Return the initialized Tensor struct + return .{ .data = result_arr }; } /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. 
@@ -724,7 +788,8 @@ pub fn Tensor( } } else { try writer.writeAll("("); - inline for (0..total) |i| { + const max_to_print = 6; + inline for (0..@min(total, max_to_print)) |i| { if (i > 0) try writer.writeAll(", "); switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[i], options), @@ -736,6 +801,8 @@ pub fn Tensor( }), else => unreachable, } + if (comptime i == max_to_print - 1 and total != max_to_print - 1) + try writer.writeAll(", ..."); } try writer.writeAll(")"); } diff --git a/src/benchmark.zig b/src/benchmark.zig index c50b3a9..488e494 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void { try bench_Scalar(&stdout_writer.interface); try stdout_writer.flush(); - // try bench_vsNative(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_crossTypeVsNative(&stdout_writer.interface); - // try stdout_writer.flush(); + try bench_vsNative(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_crossTypeVsNative(&stdout_writer.interface); + try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); } From 698e968ef8bd5a0ba191cd07291e383dc3f2dcef Mon Sep 17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 16:18:31 +0200 Subject: [PATCH 06/10] Added a high dimension benchmark --- src/benchmark.zig | 98 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/src/benchmark.zig b/src/benchmark.zig index 488e494..6371c25 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -29,6 +29,8 @@ pub fn main(init: std.process.Init) !void { try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); + try bench_HighDimTensor(&stdout_writer.interface); + try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -488,6 +490,102 @@ fn bench_Vector(writer: *std.Io.Writer) !void { try 
writer.print("└──────────────────┴──────┴─────────┴─────────┴─────────┴─────────┴─────────┘\n", .{}); } +fn bench_HighDimTensor(writer: *std.Io.Writer) !void { + const ITERS: usize = 5_000; + const SAMPLES: usize = 5; + + const getVal = struct { + fn f(comptime TT: type, i: usize, comptime mask: u7) TT { + const v: u8 = @as(u8, @truncate(i & @as(usize, mask))) + 1; + return if (comptime @typeInfo(TT) == .float) @floatFromInt(v) else @intCast(v); + } + }.f; + + const computeStats = struct { + fn f(samples: []f64, iters: usize) f64 { + std.mem.sort(f64, samples, {}, std.sort.asc(f64)); + const mid = samples.len / 2; + const median_ns = if (samples.len % 2 == 0) + (samples[mid - 1] + samples[mid]) / 2.0 + else + samples[mid]; + return median_ns / @as(f64, @floatFromInt(iters)); + } + }.f; + + try writer.print( + \\ + \\ High Dimension Tensor benchmark — {d} iterations, {d} samples/cell + \\ (Results in ns/op) + \\ + \\┌─────────────────┬──────┬──────────────┬──────────────┬──────────────┬──────────────┐ + \\│ Operation │ Type │ 2x2x2 │ 3x3x3 │ 4x4x4 │ 10x10x10x10 │ + \\├─────────────────┼──────┼──────────────┼──────────────┼──────────────┼──────────────┤ + \\ + , .{ ITERS, SAMPLES }); + + const Types = .{ i32, i64, f32, f64 }; + const TNames = .{ "i32", "i64", "f32", "f64" }; + + // Testing multiple structural bounds + const Shapes = .{ + &.{ 2, 2, 2 }, + &.{ 3, 3, 3 }, + &.{ 4, 4, 4 }, + &.{ 10, 10, 10, 10 }, + }; + + const Ops = .{ "add", "sub", "mulElem", "mulScalar", "abs" }; + + inline for (Ops, 0..) 
|op_name, o_idx| { + inline for (Types, TNames) |T, tname| { + try writer.print("│ {s:<15} │ {s:<4} │", .{ op_name, tname }); + + inline for (Shapes) |shape| { + const V = Tensor(T, .{ .L = 1 }, .{}, shape); + const Q = Tensor(T, .{ .T = 1 }, .{}, &.{1}); // For scalar broadcasting operations + + var samples: [SAMPLES]f64 = undefined; + + for (0..SAMPLES) |s_idx| { + const t_start = getTime(); + + for (0..ITERS) |i| { + std.mem.doNotOptimizeAway({ + const t1 = V.splat(getVal(T, i, 63)); + + _ = if (comptime std.mem.eql(u8, op_name, "add")) + t1.add(V.splat(getVal(T, i +% 7, 63))) + else if (comptime std.mem.eql(u8, op_name, "sub")) + t1.sub(V.splat(getVal(T, i +% 3, 63))) + else if (comptime std.mem.eql(u8, op_name, "mulElem")) + t1.mul(V.splat(getVal(T, i +% 5, 63))) + else if (comptime std.mem.eql(u8, op_name, "mulScalar")) + t1.mul(Q.splat(getVal(T, i +% 2, 63))) + else if (comptime std.mem.eql(u8, op_name, "abs")) + t1.abs() + else + unreachable; + }); + } + + const t_end = getTime(); + samples[s_idx] = @as(f64, @floatFromInt(t_start.durationTo(t_end).toNanoseconds())); + } + + const median_ns_per_op = computeStats(&samples, ITERS); + try writer.print(" {d:>12.1} │", .{median_ns_per_op}); + } + try writer.print("\n", .{}); + } + + if (o_idx < Ops.len - 1) { + try writer.print("├─────────────────┼──────┼──────────────┼──────────────┼──────────────┼──────────────┤\n", .{}); + } + } + try writer.print("└─────────────────┴──────┴──────────────┴──────────────┴──────────────┴──────────────┘\n", .{}); +} + fn vectorSIMDvsNative(comptime T: type, writer: *std.Io.Writer) !void { const iterations: u64 = 10_000; const lens = [_]u32{ 1, 2, 3, 4, 5, 10, 100, 1_000, 10_000 }; From 168312b78e6b7676578e6b03a7416e0709d29aec Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 27 Apr 2026 19:09:55 +0200 Subject: [PATCH 07/10] Removed lots of usless inline and comptime in Scales and Dimensions --- src/Dimensions.zig | 44 +++++++++++++++++++++--------------------- src/Scales.zig | 48 
+++++++++++++++++++++++----------------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/Dimensions.zig b/src/Dimensions.zig index ef37fe1..d54ced2 100644 --- a/src/Dimensions.zig +++ b/src/Dimensions.zig @@ -49,22 +49,22 @@ data: std.EnumArray(Dimension, comptime_int), /// Create a `Dimensions` from a struct literal, e.g. `.{ .L = 1, .T = -1 }`. /// Unspecified dimensions default to 0. -pub fn init(comptime init_val: ArgOpts) Self { +pub fn init(init_val: ArgOpts) Self { var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; for (std.meta.fields(@TypeOf(init_val))) |f| s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); return s; } -pub fn initFill(comptime val: comptime_int) Self { - comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; +pub fn initFill(val: comptime_int) Self { + return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; } -pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { - comptime return self.data.get(key); +pub fn get(self: Self, key: Dimension) comptime_int { + return self.data.get(key); } -pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { +pub fn set(self: *Self, key: Dimension, val: i8) void { self.data.set(key, val); } @@ -72,58 +72,58 @@ pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; for (std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); - comptime return args; + return args; } /// Add exponents component-wise. Used internally by `mul`. -pub fn add(comptime a: Self, comptime b: Self) Self { +pub fn add(a: Self, b: Self) Self { var result = Self.initFill(0); for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) + b.get(d)); - comptime return result; + return result; } /// Subtract exponents component-wise. Used internally by `div`. 
-pub fn sub(comptime a: Self, comptime b: Self) Self { +pub fn sub(a: Self, b: Self) Self { var result = Self.initFill(0); for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) - b.get(d)); - comptime return result; + return result; } /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. -pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { +pub fn scale(a: Self, exp: comptime_int) Self { var result = Self.initFill(0); for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) * exp); - comptime return result; + return result; } -pub fn div(comptime a: Self, comptime exp: comptime_int) Self { +pub fn div(a: Self, exp: comptime_int) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) / exp); - comptime return result; + return result; } /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. 
-pub fn eql(comptime a: Self, comptime b: Self) bool { - inline for (std.enums.values(Dimension)) |d| +pub fn eql(a: Self, b: Self) bool { + for (std.enums.values(Dimension)) |d| if (a.get(d) != b.get(d)) return false; return true; } -pub fn isSquare(comptime a: Self) bool { - inline for (std.enums.values(Dimension)) |d| +pub fn isSquare(a: Self) bool { + for (std.enums.values(Dimension)) |d| if (a.get(d) % 2 != 0) return false; return true; } -pub fn str(comptime a: Self) []const u8 { +pub fn str(a: Self) []const u8 { var out: []const u8 = ""; const dims = std.enums.values(Dimension); - inline for (dims) |d| { + for (dims) |d| { const val = a.get(d); if (val != 0) { out = out ++ @tagName(d) ++ std.fmt.comptimePrint("{d}", .{val}); diff --git a/src/Scales.zig b/src/Scales.zig index ea6ff08..54484b4 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -55,33 +55,33 @@ pub const UnitScale = enum(isize) { // Undefined _, - pub inline fn str(self: @This()) []const u8 { + pub fn str(self: @This()) []const u8 { var buf: [16]u8 = undefined; return switch (self) { - inline .none => "", - inline .P, .T, .G, .M, .k, .h, .da, .d, .c, .m, .u, .n, .p, .f, .min, .hour, .year, .inch, .ft, .yd, .mi, .oz, .lb, .st => @tagName(self), + .none => "", + .P, .T, .G, .M, .k, .h, .da, .d, .c, .m, .u, .n, .p, .f, .min, .hour, .year, .inch, .ft, .yd, .mi, .oz, .lb, .st => @tagName(self), else => std.fmt.bufPrint(&buf, "[{d}]", .{@intFromEnum(self)}) catch "[]", // This cannot be inline because of non exhaustive enum, but that's ok, it is just str, not calculation }; } - pub inline fn getFactor(self: @This()) comptime_float { - comptime return switch (self) { + pub fn getFactor(self: @This()) comptime_float { + return switch (self) { // Standard SI Exponents - inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), + .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, 
@floatFromInt(@intFromEnum(self))), // Time Factors - inline .min, .hour, .year => @floatFromInt(@intFromEnum(self)), + .min, .hour, .year => @floatFromInt(@intFromEnum(self)), // Imperial Length (metres) - inline .inch => 0.0254, - inline .ft => 0.3048, - inline .yd => 0.9144, - inline .mi => 1609.344, + .inch => 0.0254, + .ft => 0.3048, + .yd => 0.9144, + .mi => 1609.344, // Imperial Mass (grams — base unit for M is gram, i.e. .none = 1 g) - inline .oz => 28.3495231, - inline .lb => 453.59237, - inline .st => 6350.29318, + .oz => 28.3495231, + .lb => 453.59237, + .st => 6350.29318, else => @floatFromInt(@intFromEnum(self)), }; @@ -97,7 +97,7 @@ data: std.EnumArray(Dimension, UnitScale), /// Unspecified dimensions default to `.none` (factor 1). pub fn init(comptime init_val: ArgOpts) Self { comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) }; - inline for (std.meta.fields(@TypeOf(init_val))) |f| { + for (std.meta.fields(@TypeOf(init_val))) |f| { if (comptime @typeInfo(@TypeOf(@field(init_val, f.name))) == .comptime_int) s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name))) else @@ -106,31 +106,31 @@ pub fn init(comptime init_val: ArgOpts) Self { return comptime s; } -pub fn initFill(comptime val: UnitScale) Self { - comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; +pub fn initFill(val: UnitScale) Self { + return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; } -pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { - return comptime self.data.get(key); +pub fn get(self: Self, key: Dimension) UnitScale { + return self.data.get(key); } -pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { +pub fn set(self: *Self, key: Dimension, val: UnitScale) void { self.data.set(key, val); } pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; - inline for (std.enums.values(Dimension)) |d| + for 
(std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); return args; } /// Compute the combined scale factor for a given dimension signature. /// Each dimension's prefix is raised to its exponent and multiplied together. -pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float { +pub fn getFactor(s: Self, d: Dimensions) comptime_float { var factor: f64 = 1.0; for (std.enums.values(Dimension)) |dim| { - const power = comptime d.get(dim); + const power = d.get(dim); if (power == 0) continue; const base = s.get(dim).getFactor(); @@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float factor /= base; } } - comptime return factor; + return factor; } From a0961e757174f7307ce7a498ca5aea70ebdd5d9f Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 27 Apr 2026 21:24:03 +0200 Subject: [PATCH 08/10] Start to optimize the shit out of it, still a long way to go After that GPPPPPUUUUUU baby! --- build.zig | 2 +- src/Base.zig | 10 ++++---- src/Tensor.zig | 45 +++++++++++++++++----------------- src/benchmark.zig | 61 ++++++++++++++++++++++------------------------- 4 files changed, 57 insertions(+), 61 deletions(-) diff --git a/build.zig b/build.zig index e507138..2883cb9 100644 --- a/build.zig +++ b/build.zig @@ -2,7 +2,7 @@ const std = @import("std"); pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); - const optimize = b.standardOptimizeOption(.{}); + const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast }); // 1. 
Define the module so other projects can import it _ = b.addModule("dimal", .{ diff --git a/src/Base.zig b/src/Base.zig index 4a475a7..ebc4780 100644 --- a/src/Base.zig +++ b/src/Base.zig @@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" { // Velocity = Distance / Time const v = d.div(t); try std.testing.expectEqual(25.0, v.data[0]); - try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); + try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); // Acceleration = Velocity / Time const a = v.div(t); try std.testing.expectEqual(12.5, a.data[0]); - try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); + try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); } test "BaseQuantities - Dynamics (Force and Work)" { @@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" { // Force = mass * acceleration const f = m.mul(a); try std.testing.expectEqual(98, f.data[0]); - try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); + try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); // Energy (Work) = Force * distance const distance = Meter.Of(f32).splat(5.0); const energy = f.mul(distance); try std.testing.expectEqual(490, energy.data[0]); - try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); + try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); } test "BaseQuantities - Electric combinations" { @@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" { // Charge = Current * time const charge = current.mul(time); try std.testing.expectEqual(6.0, charge.data[0]); - try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); + try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); } test "Constants - Initialization and dimension checks" { diff --git a/src/Tensor.zig b/src/Tensor.zig index 1cc3a91..563a530 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension; // Comptime 
utilities // ───────────────────────────────────────────────────────────────────────────── -pub fn shapeTotal(comptime shape: []const comptime_int) usize { +pub fn shapeTotal(shape: []const comptime_int) usize { var t: comptime_int = 1; for (shape) |s| t *= s; return t; } /// Check if two shapes are strictly identical. -pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool { +pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool { if (a.len != b.len) return false; for (a, 0..) |v, i| if (v != b[i]) return false; @@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} -pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int { +pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int { var st: [shape.len]comptime_int = undefined; if (shape.len == 0) return st; st[shape.len - 1] = 1; @@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in } /// Return a copy of `shape` with the element at `axis` removed. -pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int { +pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int { var out: [shape.len - 1]comptime_int = undefined; var j: comptime_int = 0; for (shape, 0..) |v, i| { @@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp } /// Concatenate two compile-time slices. -pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int { +pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int { var out: [a.len + b.len]comptime_int = undefined; for (a, 0..) 
|v, i| out[i] = v; for (b, 0..) |v, i| out[a.len + i] = v; @@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i /// Decode a flat row-major index into N-D coordinates. /// Called only in comptime contexts (all arguments are comptime). -pub fn decodeFlatCoords( - comptime flat: comptime_int, - comptime n: comptime_int, - comptime strd: [n]comptime_int, -) [n]usize { +pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize { var coords: [n]comptime_int = undefined; var tmp = flat; for (0..n) |i| { @@ -75,11 +71,7 @@ pub fn decodeFlatCoords( /// Encode N-D coordinates into a flat row-major index. /// Called only in comptime contexts. -pub fn encodeFlatCoords( - comptime coords: []const usize, - comptime n: usize, - comptime strd: [n]usize, -) usize { +pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize { var flat: usize = 0; for (0..n) |i| flat += coords[i] * strd[i]; return flat; @@ -106,7 +98,7 @@ pub fn insertAxis( return out; } -fn isInt(comptime T: type) bool { +inline fn isInt(comptime T: type) bool { return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; } @@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { else scale1); } - comptime return out; + return out; } // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers // ───────────────────────────────────────────────────────────────────────────── -fn isTensor(comptime Rhs: type) bool { +inline fn isTensor(comptime Rhs: type) bool { return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); } -fn RhsTensorType(comptime T: type, comptime Rhs: type) type { +inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { if (comptime isTensor(Rhs)) return Rhs; return Tensor(T, .{}, .{}, &.{1}); } -fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { +/// Take 
the anyvalue coming from operation and if it is a Tensor, return it. +/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). +inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { const Rhs = @TypeOf(r); if (comptime isTensor(Rhs)) return r; - const scalar: T = switch (comptime @typeInfo(Rhs)) { + const scalar: T = switch (@typeInfo(Rhs)) { .comptime_int => switch (comptime @typeInfo(T)) { .float => @as(T, @floatFromInt(r)), else => @as(T, r), @@ -278,11 +272,13 @@ pub fn Tensor( shape_, ) { const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const RhsType = @TypeOf(r); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); + // if (comptime total == 1 and RhsType.scales.eql(scales)) + // return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -450,6 +446,9 @@ pub fn Tensor( const DestT = ActualDest.ValueType; const DestVec = @Vector(total, DestT); + if (comptime ratio == 1.0 and T == DestT) + return .{ .data = self.data }; + // If ratio is 1, handle type conversion correctly based on BOTH source and dest types if (comptime ratio == 1.0) { const T_info = @typeInfo(T); @@ -533,7 +532,7 @@ pub fn Tensor( pub inline fn eq(self: Self, r: anytype) CmpResult { const rhs_q = rhs(r); if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str()); + @compileError("Dimension mismatch in ne."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l == p.r); } diff --git 
a/src/benchmark.zig b/src/benchmark.zig index 6371c25..dc8bcc0 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void { io = init.io; - try vectorSIMDvsNative(f64, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(f32, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i32, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i64, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i128, &stdout_writer.interface); - try stdout_writer.flush(); - - try bench_Scalar(&stdout_writer.interface); - try stdout_writer.flush(); + // try vectorSIMDvsNative(f64, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(f32, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i32, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i64, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i128, &stdout_writer.interface); + // try stdout_writer.flush(); + // + // try bench_Scalar(&stdout_writer.interface); + // try stdout_writer.flush(); try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); - try bench_crossTypeVsNative(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_Vector(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_HighDimTensor(&stdout_writer.interface); - try stdout_writer.flush(); + // try bench_crossTypeVsNative(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_Vector(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_HighDimTensor(&stdout_writer.interface); + // try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const 
M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); - const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); + const M = Tensor(T, .{}, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { // --- 1. Benchmark Native --- const n_start = getTime(); - for (0..ITERS) |i| { - const a = getValT(T, i); - const b = getValT(T, 2); - + const a = getValT(T, 10); + const b = getValT(T, 2); + for (0..ITERS) |_| { // Native logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) - a + b + if (comptime @typeInfo(T) == .int) a +| b else a + b else if (comptime std.mem.eql(u8, op_name, "sub")) - a - b + if (comptime @typeInfo(T) == .int) a -| b else a - b else if (comptime std.mem.eql(u8, op_name, "mul")) - a * b + if (comptime @typeInfo(T) == .int) a *| b else a * b else if (comptime std.mem.eql(u8, op_name, "div")) if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b else if (comptime std.mem.eql(u8, op_name, "abs")) @@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { // --- 2. Benchmark Scalar --- const q_start = getTime(); - for (0..ITERS) |i| { - const qa = M.splat(getValT(T, i)); - const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2)); - + const qa = M.splat(getValT(T, 10)); + const qb = M.splat(getValT(T, 2)); + for (0..ITERS) |_| { // Scalar logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) qa.add(qb) From c6f613a78780621fe8b0c988919c3b464032d746 Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 27 Apr 2026 22:07:39 +0200 Subject: [PATCH 09/10] I guess I can't do better. 
Scalar still suck, but at least it is builtin SIDM so ok I guess --- src/Scales.zig | 6 +++++ src/Tensor.zig | 18 +++++++------ src/benchmark.zig | 66 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/Scales.zig b/src/Scales.zig index 54484b4..e7877e1 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -118,6 +118,12 @@ pub fn set(self: *Self, key: Dimension, val: UnitScale) void { self.data.set(key, val); } +pub fn eql(self: Self, other: Self) bool { + for (self.data.values, other.data.values) |l, r| + if (l != r) return false; + return true; +} + pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; for (std.enums.values(Dimension)) |d| diff --git a/src/Tensor.zig b/src/Tensor.zig index 563a530..9b2a593 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -203,7 +203,7 @@ pub fn Tensor( if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); } } - @setEvalBranchQuota(10_000_000); + @setEvalBranchQuota(100_000); const _total: usize = comptime shapeTotal(shape_); const _strides = comptime shapeStrides(shape_); @@ -271,20 +271,20 @@ pub fn Tensor( finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_t = rhs(r); const RhsType = @TypeOf(r); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); - // if (comptime total == 1 and RhsType.scales.eql(scales)) - // return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data }; + if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too + return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), 
finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; @@ -298,18 +298,20 @@ pub fn Tensor( finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const rhs_t = rhs(r); + const RhsType = @TypeOf(rhs_t); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); + if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too + return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; diff --git a/src/benchmark.zig b/src/benchmark.zig index dc8bcc0..c6fb592 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -26,11 +26,11 @@ pub fn main(init: std.process.Init) !void { try 
bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); // try bench_crossTypeVsNative(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_Vector(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_HighDimTensor(&stdout_writer.interface); - // try stdout_writer.flush(); + try stdout_writer.flush(); + try bench_Vector(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_HighDimTensor(&stdout_writer.interface); + try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -171,7 +171,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { fn bench_vsNative(writer: *std.Io.Writer) !void { const ITERS: usize = 100_000; - const SAMPLES: usize = 5; + const SAMPLES: usize = 100; const getValT = struct { fn f(comptime TT: type, i: usize) TT { @@ -180,8 +180,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { } }.f; - const Types = .{ i32, i64, i128, f32, f64 }; - const TNames = .{ "i32", "i64", "i128", "f32", "f64" }; + const Types = .{ f64, i64, i128, f32, f64 }; + const TNames = .{ "f64", "i64", "i128", "f32", "f64" }; // Expanded Ops to match bench_Scalar const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" }; @@ -189,16 +189,17 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { \\ \\ Scalar vs Native Overhead Analysis \\ - \\┌───────────┬──────┬───────────┬───────────┬───────────┐ - \\│ Operation │ Type │ Native │ Scalar │ Slowdown │ - \\├───────────┼──────┼───────────┼───────────┼───────────┤ + \\┌───────────┬──────┬───────────┬───────────┬───────────┬───────────────────────┐ + \\│ Operation │ Type │ Native │ @Vector │ Tensor{{1}} │ Slowdown Nat | Vec │ + \\├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤ \\ , .{}); inline for (Ops, 0..) |op_name, j| { inline for (Types, 0..) 
|T, tidx| { var native_total_ns: f64 = 0; - var quantity_total_ns: f64 = 0; + var vector_total_ns: f64 = 0; + var tensor_total_ns: f64 = 0; const M = Tensor(T, .{}, .{}, &.{1}); @@ -230,6 +231,31 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { const n_end = getTime(); native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds())); + const v_start = getTime(); + const va = getValT(T, 10); + const vb = getValT(T, 2); + for (0..ITERS) |_| { + // Native logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + if (comptime @typeInfo(T) == .int) va +| vb else va + vb + else if (comptime std.mem.eql(u8, op_name, "sub")) + if (comptime @typeInfo(T) == .int) va -| vb else va - vb + else if (comptime std.mem.eql(u8, op_name, "mul")) + if (comptime @typeInfo(T) == .int) va *| vb else va * vb + else if (comptime std.mem.eql(u8, op_name, "div")) + if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb + else if (comptime std.mem.eql(u8, op_name, "abs")) + if (comptime @typeInfo(T) == .int) @abs(va) else @as(T, @abs(va)) + else if (comptime std.mem.eql(u8, op_name, "eq")) + va == vb + else if (comptime std.mem.eql(u8, op_name, "gt")) + va > vb + else + unreachable; + } + const v_end = getTime(); + vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds())); + // --- 2. 
Benchmark Scalar --- const q_start = getTime(); const qa = M.splat(getValT(T, 10)); @@ -254,22 +280,24 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { unreachable; } const q_end = getTime(); - quantity_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); + tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); } }); const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const avg_q = (quantity_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const slowdown = avg_q / avg_n; + const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const avg_t = (tensor_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const slowdown_nt = avg_t / avg_n; + const slowdown_vt = avg_t / avg_v; - try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x │\n", .{ - op_name, TNames[tidx], avg_n, avg_q, slowdown, + try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x {d:>8.2}x │\n", .{ + op_name, TNames[tidx], avg_n, avg_v, avg_t, slowdown_nt, slowdown_vt, }); } - if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┤\n", .{}); + if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤\n", .{}); } - try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┘\n", .{}); + try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┴───────────────────────┘\n", .{}); } fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void { From 0b2469d3212fed77f7fc379c5003c4af2fdf9351 Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 27 Apr 2026 22:07:45 +0200 Subject: [PATCH 10/10] . 
--- current release.md | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 current release.md diff --git a/current release.md b/current release.md new file mode 100644 index 0000000..30759d6 --- /dev/null +++ b/current release.md @@ -0,0 +1,11 @@ +- Changed Quantity to Tensor that can use any shape and is a single @Vector. + Point being to add WebGPU easily from this. + Scalar suffers in performance though; I will work on that + +Maybe I can do a Jupyter-like web interface with cells to do dimensional analysis +I could: + - Use cells with a toy language + - A nice debugger to display current variables with dimensions, type and value + - Real-time errors (I try to compile on change, displaying errors on the cell) + - Integrate a small graphic API that uses a Raylib canvas + - Could generate templates at comptime =o