diff --git a/build.zig b/build.zig index e507138..2883cb9 100644 --- a/build.zig +++ b/build.zig @@ -2,7 +2,7 @@ const std = @import("std"); pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); - const optimize = b.standardOptimizeOption(.{}); + const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast }); // 1. Define the module so other projects can import it _ = b.addModule("dimal", .{ diff --git a/current release.md b/current release.md new file mode 100644 index 0000000..30759d6 --- /dev/null +++ b/current release.md @@ -0,0 +1,11 @@ +- Changed Quantity to Tensor that can use any shape and is a single @Vector. + Point being to add WebGPU easily from this. + Scalars suffer in performance though; I will work on that + +Maybe I can do a Jupyter-like web interface with cells for dimensional analysis +I could: + - Use cells with a toy language + - A nice debugger to display current variables with dimensions, type and value + - Realtime errors (I try to compile on each change, displaying errors on the cell) + - Integrate a small graphics API that uses a Raylib canvas + - Could generate templates at comptime =o diff --git a/src/Base.zig b/src/Base.zig index 4bde68b..ebc4780 100644 --- a/src/Base.zig +++ b/src/Base.zig @@ -3,34 +3,39 @@ const std = @import("std"); // Adjust these imports to match your actual file names const Dimensions = @import("Dimensions.zig"); const Scales = @import("Scales.zig"); -const Scalar = @import("Quantity.zig").Scalar; +const Tensor = @import("Tensor.zig").Tensor; fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type { return struct { - const dims = Dimensions.init(d); - const scales = Scales.init(s); + pub const dims = Dimensions.init(d); + pub const scales = Scales.init(s); /// Instantiates the constant into a specific numeric type. 
- pub fn Of(comptime T: type) Scalar(T, d, s) { - return .{ .data = @splat(@as(T, @floatCast(val))) }; + pub fn Of(comptime T: type) Tensor(T, d, s, &.{1}) { + const casted_val: T = switch (@typeInfo(T)) { + .float => @floatCast(val), + .int => @intFromFloat(val), + else => @compileError("Unsupported type for PhysicalConstant"), + }; + return Tensor(T, d, s, &.{1}).splat(casted_val); } }; } fn BaseScalar(comptime d: Dimensions.ArgOpts) type { return struct { - const dims = Dimensions.init(d); + pub const dims = Dimensions.init(d); /// Creates a Scalar of this dimension using default scales. - /// Example: const V = Quantities.Velocity.Base(f32); + /// Example: const V = Quantities.Velocity.Of(f32); pub fn Of(comptime T: type) type { - return Scalar(T, d, .{}); + return Tensor(T, d, .{}, &.{1}); } /// Creates a Scalar of this dimension using custom scales. - /// Example: const Kmh = Quantities.Velocity.Scaled(f32, Scales.init(.{ .L = .k, .T = .hour })); + /// Example: const Kmh = Quantities.Velocity.Scaled(f32, .{ .L = .k, .T = .hour }); pub fn Scaled(comptime T: type, comptime s: Scales.ArgOpts) type { - return Scalar(T, d, s); + return Tensor(T, d, s, &.{1}); } }; } @@ -107,7 +112,7 @@ pub const ElectricCapacitance = BaseScalar(.{ .T = 4, .L = -2, .M = -1, .I = 2 } pub const ElectricImpedance = ElectricResistance; pub const MagneticFlux = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .I = -1 }); pub const MagneticDensity = BaseScalar(.{ .M = 1, .T = -2, .I = -1 }); -pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); // Fixed typo from MagneticStrengh +pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); pub const MagneticMoment = BaseScalar(.{ .L = 2, .I = 1 }); // ========================================== @@ -140,7 +145,7 @@ pub const ThermalHeat = Energy; pub const ThermalWork = Energy; pub const ThermalCapacity = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 }); pub const ThermalCapacityPerMass = BaseScalar(.{ .L = 2, .T = -2, .Tr = -1 }); 
-pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); // Fixed typo from ThermalluxDensity +pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); pub const ThermalConductance = BaseScalar(.{ .M = 1, .L = 2, .T = -3, .Tr = -1 }); pub const ThermalConductivity = BaseScalar(.{ .M = 1, .L = 1, .T = -3, .Tr = -1 }); pub const ThermalResistance = BaseScalar(.{ .M = -1, .L = -2, .T = 3, .Tr = 1 }); @@ -152,20 +157,24 @@ pub const ThermalEntropy = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 }); // ========================================== pub const Frequency = BaseScalar(.{ .T = -1 }); pub const Viscosity = BaseScalar(.{ .M = 1, .L = -1, .T = -1 }); -pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); // Corrected from MT-2a +pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); + +// ========================================== +// Tests +// ========================================== test "BaseQuantities - Core dimensions instantiation" { // Basic types via generic wrappers const M = Meter.Of(f32); const distance = M.splat(100); - try std.testing.expectEqual(100.0, distance.value()); + try std.testing.expectEqual(100.0, distance.data[0]); try std.testing.expectEqual(1, M.dims.get(.L)); try std.testing.expectEqual(0, M.dims.get(.T)); // Test specific scale variants const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour }); const speed = Kmh.splat(120); - try std.testing.expectEqual(120.0, speed.value()); + try std.testing.expectEqual(120.0, speed.data[0]); try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L)); try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T)); } @@ -176,13 +185,13 @@ test "BaseQuantities - Kinematics equations" { // Velocity = Distance / Time const v = d.div(t); - try std.testing.expectEqual(25.0, v.value()); - try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); + try std.testing.expectEqual(25.0, v.data[0]); + try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); // 
Acceleration = Velocity / Time const a = v.div(t); - try std.testing.expectEqual(12.5, a.value()); - try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); + try std.testing.expectEqual(12.5, a.data[0]); + try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); } test "BaseQuantities - Dynamics (Force and Work)" { @@ -193,14 +202,14 @@ test "BaseQuantities - Dynamics (Force and Work)" { // Force = mass * acceleration const f = m.mul(a); - try std.testing.expectEqual(98, f.value()); - try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); + try std.testing.expectEqual(98, f.data[0]); + try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); // Energy (Work) = Force * distance const distance = Meter.Of(f32).splat(5.0); const energy = f.mul(distance); - try std.testing.expectEqual(490, energy.value()); - try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); + try std.testing.expectEqual(490, energy.data[0]); + try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); } test "BaseQuantities - Electric combinations" { @@ -209,26 +218,26 @@ test "BaseQuantities - Electric combinations" { // Charge = Current * time const charge = current.mul(time); - try std.testing.expectEqual(6.0, charge.value()); - try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); + try std.testing.expectEqual(6.0, charge.data[0]); + try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); } test "Constants - Initialization and dimension checks" { // Speed of Light const c = Constants.SpeedOfLight.Of(f64); - try std.testing.expectEqual(299792458.0, c.value()); + try std.testing.expectEqual(299792458.0, c.data[0]); try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L)); try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T)); // Electron Mass (verifying scale as well) const me = Constants.ElectronMass.Of(f64); - try std.testing.expectEqual(9.1093837139e-31, me.value()); + try 
std.testing.expectEqual(9.1093837139e-31, me.data[0]); try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M)); try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg // Boltzmann Constant (Complex derived dimensions) const kb = Constants.Boltzmann.Of(f64); - try std.testing.expectEqual(1.380649e-23, kb.value()); + try std.testing.expectEqual(1.380649e-23, kb.data[0]); try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M)); try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L)); try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T)); @@ -237,7 +246,7 @@ test "Constants - Initialization and dimension checks" { // Vacuum Permittivity const eps0 = Constants.VacuumPermittivity.Of(f64); - try std.testing.expectEqual(8.8541878188e-12, eps0.value()); + try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]); try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M)); try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L)); try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T)); @@ -245,7 +254,7 @@ test "Constants - Initialization and dimension checks" { // Fine Structure Constant (Dimensionless) const alpha = Constants.FineStructure.Of(f64); - try std.testing.expectEqual(0.0072973525643, alpha.value()); + try std.testing.expectEqual(0.0072973525643, alpha.data[0]); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M)); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L)); } diff --git a/src/Dimensions.zig b/src/Dimensions.zig index e91e78d..d54ced2 100644 --- a/src/Dimensions.zig +++ b/src/Dimensions.zig @@ -49,81 +49,81 @@ data: std.EnumArray(Dimension, comptime_int), /// Create a `Dimensions` from a struct literal, e.g. `.{ .L = 1, .T = -1 }`. /// Unspecified dimensions default to 0. 
-pub fn init(comptime init_val: ArgOpts) Self { +pub fn init(init_val: ArgOpts) Self { var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; - inline for (std.meta.fields(@TypeOf(init_val))) |f| + for (std.meta.fields(@TypeOf(init_val))) |f| s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); return s; } -pub fn initFill(comptime val: comptime_int) Self { +pub fn initFill(val: comptime_int) Self { return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; } -pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { +pub fn get(self: Self, key: Dimension) comptime_int { return self.data.get(key); } -pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { +pub fn set(self: *Self, key: Dimension, val: i8) void { self.data.set(key, val); } pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); return args; } /// Add exponents component-wise. Used internally by `mul`. -pub fn add(comptime a: Self, comptime b: Self) Self { +pub fn add(a: Self, b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) + b.get(d)); return result; } /// Subtract exponents component-wise. Used internally by `div`. -pub fn sub(comptime a: Self, comptime b: Self) Self { +pub fn sub(a: Self, b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) - b.get(d)); return result; } /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. 
-pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { +pub fn scale(a: Self, exp: comptime_int) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) * exp); return result; } -pub fn div(comptime a: Self, comptime exp: comptime_int) Self { +pub fn div(a: Self, exp: comptime_int) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) / exp); return result; } /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. -pub fn eql(comptime a: Self, comptime b: Self) bool { - inline for (std.enums.values(Dimension)) |d| +pub fn eql(a: Self, b: Self) bool { + for (std.enums.values(Dimension)) |d| if (a.get(d) != b.get(d)) return false; return true; } -pub fn isSquare(comptime a: Self) bool { - inline for (std.enums.values(Dimension)) |d| +pub fn isSquare(a: Self) bool { + for (std.enums.values(Dimension)) |d| if (a.get(d) % 2 != 0) return false; return true; } -pub fn str(comptime a: Self) []const u8 { +pub fn str(a: Self) []const u8 { var out: []const u8 = ""; const dims = std.enums.values(Dimension); - inline for (dims) |d| { + for (dims) |d| { const val = a.get(d); if (val != 0) { out = out ++ @tagName(d) ++ std.fmt.comptimePrint("{d}", .{val}); diff --git a/src/Quantity.zig b/src/Quantity.zig deleted file mode 100644 index 4e30ec9..0000000 --- a/src/Quantity.zig +++ /dev/null @@ -1,1259 +0,0 @@ -const std = @import("std"); -const hlp = @import("helper.zig"); -const Scales = @import("Scales.zig"); -const UnitScale = Scales.UnitScale; -const Dimensions = @import("Dimensions.zig"); -const Dimension = Dimensions.Dimension; - -// --------------------------------------------------------------------------- -// Quantity — single unified dimensioned type. 
-// -// T : element numeric type (f32, f64, i32, i128, …) -// N : lane count (1 = Scalar, >1 = Vector) -// d : SI dimension exponents -// s : unit scales -// -// All arithmetic is performed directly on the underlying @Vector(N, T), so -// the compiler can emit SIMD instructions wherever the target supports them. -// -// Thin aliases (same type identity, no wrapper overhead): -// Scalar(T, d, s) ≡ Quantity(T, 1, d, s) -// Vector(N, Q) ≡ Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()) -// --------------------------------------------------------------------------- -// -// To investigate: -// @reduce(comptime op: std.builtin.ReduceOp, value: anytype) E -// @select(comptime T: type, pred: @Vector(len, bool), a: @Vector(len, T), b: @Vector(len, T)) @Vector(len, T) -// @shuffle(comptime E: type, a: @Vector(a_len, E), b: @Vector(b_len, E), comptime mask: @Vector(mask_len, i32)) @Vector(mask_len, E) - -pub fn Quantity( - comptime T: type, - comptime N: usize, - comptime d_opt: Dimensions.ArgOpts, - comptime s_opt: Scales.ArgOpts, -) type { - comptime std.debug.assert(N >= 1); - @setEvalBranchQuota(10_000_000); - - // Local shorthand for the SIMD vector type used in storage. - const Vec = @Vector(N, T); - - return struct { - /// SIMD-friendly storage. Arithmetic operates here directly. - data: Vec, - - const Self = @This(); - - pub const ValueType: type = T; - pub const Len: usize = N; - pub const dims: Dimensions = Dimensions.init(d_opt); - pub const scales: Scales = Scales.init(s_opt); - pub const ISQUANTITY = true; - - /// Scalar variant of this quantity (lane=1). Returned by dot(), product(), etc. - pub const ScalarType: type = Quantity(T, 1, d_opt, s_opt); - /// Convenience: a 3-lane vector of the same dimension/scale. 
- pub const Vec3: type = Quantity(T, 3, d_opt, s_opt); - - // --------------------------------------------------------------- - // Constructors - // --------------------------------------------------------------- - - /// Broadcast a single value across all N lanes. - pub inline fn splat(v: T) Self { - return .{ .data = @splat(v) }; - } - - pub const zero: Self = splat(0); - pub const one: Self = splat(1); - - // --------------------------------------------------------------- - // Scalar-only helpers (N = 1) - // --------------------------------------------------------------- - - /// Return the single scalar value. Compile error when N ≠ 1. - pub inline fn value(self: Self) T { - comptime if (N != 1) - @compileError(".value() is only available on Scalar (N=1)."); - return self.data[0]; - } - - /// Expand this scalar into a len-lane vector by splatting. - pub inline fn vec(self: Self, comptime len: usize) Quantity(T, len, d_opt, s_opt) { - comptime if (N != 1) - @compileError(".vec() is only available on Scalar (N=1)."); - return .{ .data = @splat(self.data[0]) }; - } - - pub inline fn vec3(self: Self) Vec3 { - return self.vec(3); - } - - // --------------------------------------------------------------- - // Internal: RHS normalisation - // - // • For N=1 (Scalar context): bare numbers → Quantity(T, 1, dimless, none) - // • For N>1 (Vector context): bare numbers → Quantity(T, N, dimless, none) - // Quantity(T,1) → broadcast (handled in each op) - // - // A bare number used as rhs is ALWAYS treated as dimensionless. - // --------------------------------------------------------------- - - inline fn RhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, N, Rhs); - } - inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, N, r); - } - - /// Scalar rhs (N=1) — used by mulScalar / divScalar / eqScalar etc. 
- inline fn ScalarRhsT(comptime Rhs: type) type { - return hlp.rhsQuantityType(T, 1, Rhs); - } - inline fn scalarRhs(r: anytype) ScalarRhsT(@TypeOf(r)) { - return hlp.toRhsQuantity(T, 1, r); - } - - // --------------------------------------------------------------- - // Internal: broadcast helper - // - // When an N=1 rhs is used in an N>1 operation, splat it. - // --------------------------------------------------------------- - inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { - if (comptime RhsType.Len == 1 and N > 1) - return @splat(r.data[0]) - else - return r.data; - } - - // --------------------------------------------------------------- - // Arithmetic - // --------------------------------------------------------------- - - /// Element-wise add. Dimensions must match; scales resolve to finer. - /// For N=1: rhs may be a bare number (treated as dimensionless). - /// For N>1: rhs must be a same-length Quantity. - pub inline fn add(self: Self, r: anytype) Quantity( - T, - N, - dims.argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), - ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); - if (comptime !dims.eql(RhsType.dims)) - @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector add requires same-length Quantity."); - - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; - return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr }; - } - - /// Element-wise subtract. Dimensions must match; scales resolve to finer. 
- pub inline fn sub(self: Self, r: anytype) Quantity( - T, - N, - dims.argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), - ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); - if (comptime !dims.eql(RhsType.dims)) - @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime N > 1 and RhsType.Len != N) - @compileError("Vector sub requires same-length Quantity."); - - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; - const rr: Vec = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data; - return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr }; - } - - /// Element-wise multiply. Dimension exponents are summed. - /// An N=1 rhs on an N>1 self is automatically broadcast (scalar × vector). - pub inline fn mul(self: Self, r: anytype) Quantity( - T, - N, - dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), - ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); - return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr }; - } - - /// Element-wise divide. Dimension exponents are subtracted. - /// An N=1 rhs on an N>1 self is automatically broadcast. 
- pub inline fn div(self: Self, r: anytype) Quantity( - T, - N, - dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), - hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), - ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const RhsNorm = Quantity(T, RhsType.Len, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); - if (comptime hlp.isInt(T)) { - var result: Vec = undefined; - inline for (0..N) |i| result[i] = @divTrunc(l[i], rr[i]); - return .{ .data = result }; - } else { - return .{ .data = l / rr }; - } - } - - // --------------------------------------------------------------- - // Unary - // --------------------------------------------------------------- - - /// Absolute value of every lane. Uses native `@abs` (SIMD for floats & ints). - pub inline fn abs(self: Self) Self { - return .{ .data = @bitCast(@abs(self.data)) }; - } - - /// Raise every lane to a comptime integer exponent. - /// Repeated SIMD multiply — good for small exponents. - pub inline fn pow(self: Self, comptime exp: comptime_int) Quantity( - T, - N, - dims.scale(exp).argsOpt(), - scales.argsOpt(), - ) { - if (comptime hlp.isInt(T)) { - // No SIMD pow for integers — element-wise std.math.powi. - var result: Vec = undefined; - inline for (0..N) |i| - result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); - return .{ .data = result }; - } else { - // Float: unrolled SIMD multiplications. 
- const abs_exp = comptime @abs(exp); - var result: Vec = @splat(1); - comptime var i = 0; - inline while (i < abs_exp) : (i += 1) result *= self.data; - if (comptime exp < 0) result = @as(Vec, @splat(1)) / result; - return .{ .data = result }; - } - } - - /// Square root of every lane. All dimension exponents must be even. - pub inline fn sqrt(self: Self) Quantity( - T, - N, - dims.div(2).argsOpt(), - scales.argsOpt(), - ) { - if (comptime !dims.isSquare()) - @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); - if (comptime @typeInfo(T) == .float) { - return .{ .data = @sqrt(self.data) }; - } else { - // Integer sqrt is not SIMD-able — element-wise. - var result: Vec = undefined; - const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - inline for (0..N) |i| { - const v = self.data[i]; - if (v < 0) - result[i] = 0 - else - result[i] = @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); - } - return .{ .data = result }; - } - } - - /// Negate every lane. - pub inline fn negate(self: Self) Self { - return .{ .data = -self.data }; - } - - // --------------------------------------------------------------- - // Conversion - // --------------------------------------------------------------- - - /// Convert to a compatible quantity type. Dimension mismatch is a compile error. - /// Dest can have the same Len as this quantity, or Len == 1 (in which case it - /// will be automatically recast to this quantity's Len). - /// The scale ratio is computed entirely at comptime; the only runtime cost is - /// a SIMD multiply-by-splat (or element-wise cast for cross-numeric-type conversions). 
- pub inline fn to( - self: Self, - comptime Dest: type, - ) Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()) { - const ActualDest = Quantity(Dest.ValueType, N, Dest.dims.argsOpt(), Dest.scales.argsOpt()); - - if (comptime !dims.eql(ActualDest.dims)) - @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - - if (comptime Self == ActualDest) return self; - - // Allow Dest to be exactly matching Len or a Scalar (Len == 1) - comptime std.debug.assert(Dest.Len == N or Dest.Len == 1); - - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestVec = @Vector(N, DestT); - - // ── Same numeric type path ── - if (comptime T == DestT) { - if (comptime @typeInfo(T) == .float) - return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; - - // Integer logic: Branching prevents division by zero errors - if (comptime ratio >= 1.0) { - // Upscaling (e.g., km -> m, ratio = 1000) - const mult: T = comptime @intFromFloat(@round(ratio)); - return .{ .data = self.data *| @as(Vec, @splat(mult)) }; - } else { - // Downscaling (e.g., m -> km, ratio = 0.001) - const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); - var result: DestVec = undefined; - const half: T = comptime @divTrunc(div_val, 2); - - inline for (0..N) |i| { - const val = self.data[i]; - // Rounding division for integers - result[i] = if (val >= 0) @divTrunc(val + half, div_val) else @divTrunc(val - half, div_val); - } - return .{ .data = result }; - } - } - - // ── Cross-numeric-type (unchanged) ── - var result: DestVec = undefined; - inline for (0..N) |i| { - const float_val: f64 = switch (comptime @typeInfo(T)) { - .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, - }; - const scaled = float_val * ratio; - result[i] = switch (comptime @typeInfo(DestT)) { - .float => @floatCast(scaled), - 
.int => @intFromFloat(@round(scaled)), - else => unreachable, - }; - } - return .{ .data = result }; - } - - // --------------------------------------------------------------- - // Comparisons - // - // Return type: bool when N = 1 (Scalar semantics) - // [N]bool when N > 1 (Vector semantics, element-wise) - // - // Whole-vector "all equal/any differ" → use eqAll / neAll. - // Broadcast scalar comparison → use eqScalar / gtScalar / … - // --------------------------------------------------------------- - - const CmpResult = if (N == 1) bool else [N]bool; - - inline fn cmpResult(v: @Vector(N, bool)) CmpResult { - return if (comptime N == 1) v[0] else @as([N]bool, v); - } - - inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { - const RhsType = @TypeOf(rhs_q); - const TargetType = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt()); - return .{ - .l = if (comptime Self == TargetType) self.data else self.to(TargetType).data, - .r = if (comptime RhsType == TargetType) rhs_q.data else rhs_q.to(TargetType).data, - }; - } - - pub inline fn eq(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str()); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l == p.r); - } - - pub inline fn ne(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in ne."); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l != p.r); - } - - pub inline fn gt(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in gt."); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l > p.r); - } - - pub inline fn gte(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime 
!dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in gte."); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l >= p.r); - } - - pub inline fn lt(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in lt."); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l < p.r); - } - - pub inline fn lte(self: Self, r: anytype) CmpResult { - const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in lte."); - const p = resolveScalePair(self, rhs_q); - return cmpResult(p.l <= p.r); - } - - // --------------------------------------------------------------- - // Vector whole-quantity comparisons (N > 1 intended, but work for N=1 too) - // --------------------------------------------------------------- - - /// True iff every lane is equal after scale resolution. - pub inline fn eqAll(self: Self, other: anytype) bool { - if (comptime !dims.eql(@TypeOf(other).dims)) - @compileError("Dimension mismatch in eqAll."); - const p = resolveScalePair(self, other); - return @reduce(.And, p.l == p.r); - } - - pub inline fn neAll(self: Self, other: anytype) bool { - return !self.eqAll(other); - } - - // --------------------------------------------------------------- - // Vector broadcast-scalar comparisons (always returns [N]bool) - // --------------------------------------------------------------- - - inline fn broadcastScalarForCmp(self: Self, scalar: anytype) struct { l: Vec, r: Vec } { - const s = scalarRhs(scalar); - const SN = @TypeOf(s); - const TargetScalar = Quantity(T, 1, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const TargetSelf = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, SN).argsOpt()); - const s_val: T = if (comptime SN == TargetScalar) s.data[0] else s.to(TargetScalar).data[0]; - const l: Vec = if (comptime Self == TargetSelf) self.data else self.to(TargetSelf).data; - 
return .{ .l = l, .r = @splat(s_val) }; - } - - pub inline fn eqScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l == p.r); - } - - pub inline fn neScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l != p.r); - } - - pub inline fn gtScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l > p.r); - } - - pub inline fn gteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l >= p.r); - } - - pub inline fn ltScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l < p.r); - } - - pub inline fn lteScalar(self: Self, scalar: anytype) [N]bool { - const p = broadcastScalarForCmp(self, scalar); - return @as([N]bool, p.l <= p.r); - } - - // --------------------------------------------------------------- - // Vector broadcast multiply / divide - // (These are explicit aliases for mul/div with an N=1 rhs; kept for - // clarity and backward-compat with the old Vector API.) - // --------------------------------------------------------------- - - pub inline fn mulScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.add(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.mul(scalar); - } - - pub inline fn divScalar(self: Self, scalar: anytype) Quantity( - T, - N, - dims.sub(ScalarRhsT(@TypeOf(scalar)).dims).argsOpt(), - hlp.finerScales(Self, ScalarRhsT(@TypeOf(scalar))).argsOpt(), - ) { - return self.div(scalar); - } - - // --------------------------------------------------------------- - // Vector geometric operations - // --------------------------------------------------------------- - - /// Dot product — sum of element-wise products; returns a Scalar. 
- pub inline fn dot(self: Self, other: anytype) Quantity( - T, - 1, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - const Tr = @TypeOf(other); - const SelfNorm = Quantity(T, N, dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const OtherNorm = Quantity(T, N, Tr.dims.argsOpt(), hlp.finerScales(Self, Tr).argsOpt()); - const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const r2: Vec = if (comptime Tr == OtherNorm) other.data else other.to(OtherNorm).data; - return .{ .data = .{@reduce(.Add, l * r2)} }; - } - - /// 3D cross product. Requires N = 3. - pub inline fn cross(self: Self, other: anytype) Quantity( - T, - 3, - dims.add(@TypeOf(other).dims).argsOpt(), - hlp.finerScales(Self, @TypeOf(other)).argsOpt(), - ) { - comptime if (N != 3) @compileError("cross() requires len=3."); - const a = self.data; - const b = other.data; - return .{ .data = .{ - a[1] * b[2] - a[2] * b[1], - a[2] * b[0] - a[0] * b[2], - a[0] * b[1] - a[1] * b[0], - } }; - } - - /// Sum of squared components. Cheaper than length(); use for comparisons. - pub inline fn lengthSqr(self: Self) T { - return @reduce(.Add, self.data * self.data); - } - - /// Euclidean length. Float types use SIMD @reduce → @sqrt. - /// Integer types use integer sqrt (truncated). - pub inline fn length(self: Self) T { - const sq = self.lengthSqr(); - if (comptime @typeInfo(T) == .int) { - const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - return @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(sq))))); - } - return @sqrt(sq); - } - - /// Product of all components. Result dimension is (original dim × N). 
- pub inline fn product(self: Self) Quantity(T, 1, dims.scale(N).argsOpt(), scales.argsOpt()) { - return .{ .data = .{@reduce(.Mul, self.data)} }; - } - - // --------------------------------------------------------------- - // Formatting (unchanged from old Scalar / Vector) - // --------------------------------------------------------------- - - pub fn formatNumber( - self: Self, - writer: *std.Io.Writer, - options: std.fmt.Number, - ) !void { - if (comptime N == 1) { - // Scalar-style: just print the value + units - switch (@typeInfo(T)) { - .float, .comptime_float => try writer.printFloat(self.data[0], options), - .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ - .width = options.width, - .alignment = options.alignment, - .fill = options.fill, - .precision = options.precision, - }), - else => unreachable, - } - } else { - // Vector-style: (v0, v1, …) + units - try writer.writeAll("("); - inline for (0..N) |i| { - if (i > 0) try writer.writeAll(", "); - switch (@typeInfo(T)) { - .float, .comptime_float => try writer.printFloat(self.data[i], options), - .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ - .width = options.width, - .alignment = options.alignment, - .fill = options.fill, - .precision = options.precision, - }), - else => unreachable, - } - } - try writer.writeAll(")"); - } - - var first = true; - inline for (std.enums.values(Dimension)) |bu| { - const v = dims.get(bu); - if (comptime v == 0) continue; - if (!first) try writer.writeAll("."); - first = false; - - const uscale = scales.get(bu); - if (bu == .T and (uscale == .min or uscale == .hour or uscale == .year)) - try writer.print("{s}", .{uscale.str()}) - else - try writer.print("{s}{s}", .{ uscale.str(), bu.unit() }); - - if (v != 1) try hlp.printSuperscript(writer, v); - } - } - }; -} - -// Scalar tests - -pub fn Scalar(comptime T: type, comptime d: Dimensions.ArgOpts, comptime s: Scales.ArgOpts) type { - return Quantity(T, 1, d, s); -} - -test 
"Scalar initiat" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }); - const Second = Scalar(f32, .{ .T = 1 }, .{ .T = .n }); - - const distance = Meter.splat(10); - const time = Second.splat(2); - - try std.testing.expectEqual(10, distance.value()); - try std.testing.expectEqual(2, time.value()); -} - -test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - - const m1000 = Meter.splat(1000); - const km1 = KiloMeter.splat(1); - const km2 = KiloMeter.splat(2); - - // Equal / Not Equal - try std.testing.expect(m1000.eq(km1)); - try std.testing.expect(km1.eq(m1000)); - try std.testing.expect(km2.ne(m1000)); - - // Greater Than / Greater Than or Equal - try std.testing.expect(km2.gt(m1000)); - try std.testing.expect(km2.gt(km1)); - try std.testing.expect(km1.gte(m1000)); - try std.testing.expect(km2.gte(m1000)); - - // Less Than / Less Than or Equal - try std.testing.expect(m1000.lt(km2)); - try std.testing.expect(km1.lt(km2)); - try std.testing.expect(km1.lte(m1000)); - try std.testing.expect(m1000.lte(km2)); -} - -test "Scalar Add" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - - const distance = Meter.splat(10); - const distance2 = Meter.splat(20); - - const added = distance.add(distance2); - try std.testing.expectEqual(30, added.value()); - try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); - - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const distance3 = KiloMeter.splat(2); - const added2 = distance.add(distance3); - try std.testing.expectEqual(2010, added2.value()); - try std.testing.expectEqual(1, @TypeOf(added2).dims.get(.L)); - - const added3 = distance3.add(distance).to(KiloMeter); - try std.testing.expectEqual(2, added3.value()); - try std.testing.expectEqual(1, @TypeOf(added3).dims.get(.L)); - - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); - const distance4 = KiloMeter_f.splat(2); - 
const added4 = distance4.add(distance).to(KiloMeter_f); - try std.testing.expectApproxEqAbs(2.01, added4.value(), 0.000001); - try std.testing.expectEqual(1, @TypeOf(added4).dims.get(.L)); -} - -test "Scalar Sub" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const KiloMeter_f = Scalar(f64, .{ .L = 1 }, .{ .L = .k }); - - const a = Meter.splat(500); - const b = Meter.splat(200); - - const diff = a.sub(b); - try std.testing.expectEqual(300, diff.value()); - const diff2 = b.sub(a); - try std.testing.expectEqual(-300, diff2.value()); - - const km_f = KiloMeter_f.splat(2.5); - const m_f = Meter.splat(500); - const diff3 = km_f.sub(m_f); - try std.testing.expectApproxEqAbs(2000, diff3.value(), 1e-4); -} - -test "Scalar MulBy" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const d = Meter.splat(3.0); - const t = Second.splat(4.0); - - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.L)); - try std.testing.expectEqual(1, @TypeOf(area_time).dims.get(.T)); - - const d2 = Meter.splat(5.0); - const area = d.mul(d2); - try std.testing.expectEqual(15, area.value()); - try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L)); -} - -test "Scalar MulBy with scale" { - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); - const KiloGram = Scalar(f32, .{ .M = 1 }, .{ .M = .k }); - - const dist = KiloMeter.splat(2.0); - const mass = KiloGram.splat(3.0); - const prod = dist.mul(mass); - try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.L)); - try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.M)); -} - -test "Scalar MulBy with type change" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Second = Scalar(f64, .{ .T = 1 }, .{}); - const KmSec = Scalar(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }); - const KmSec_f = Scalar(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }); - - const d = Meter.splat(3.0); - const t = 
Second.splat(4.0); - - const area_time = d.mul(t).to(KmSec); - const area_time_f = d.mul(t).to(KmSec_f); - try std.testing.expectEqual(12, area_time.value()); - try std.testing.expectApproxEqAbs(12.0, area_time_f.value(), 0.0001); -} - -test "Scalar MulBy small" { - const Meter = Scalar(i128, .{ .L = 1 }, .{ .L = .n }); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const d = Meter.splat(3.0); - const t = Second.splat(4.0); - - const area_time = d.mul(t); - try std.testing.expectEqual(12, area_time.value()); -} - -test "Scalar MulBy dimensionless" { - const DimLess = Scalar(i128, .{}, .{}); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - - const d = Meter.splat(7); - const scaled = d.mul(DimLess.splat(3)); - try std.testing.expectEqual(21, scaled.value()); -} - -test "Scalar Sqrt" { - const MeterSquare = Scalar(i128, .{ .L = 2 }, .{}); - - var d = MeterSquare.splat(9); - var scaled = d.sqrt(); - try std.testing.expectEqual(3, scaled.value()); - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - - d = MeterSquare.splat(-5); - scaled = d.sqrt(); - try std.testing.expectEqual(0, scaled.value()); - - const MeterSquare_f = Scalar(f64, .{ .L = 2 }, .{}); - const d2 = MeterSquare_f.splat(20); - const scaled2 = d2.sqrt(); - try std.testing.expectApproxEqAbs(4.472135955, scaled2.value(), 1e-4); -} - -test "Scalar Chained: velocity and acceleration" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const dist = Meter.splat(100.0); - const t1 = Second.splat(5.0); - const velocity = dist.div(t1); - try std.testing.expectEqual(20, velocity.value()); - - const t2 = Second.splat(4.0); - const accel = velocity.div(t2); - try std.testing.expectEqual(5, accel.value()); -} - -test "Scalar DivBy integer exact" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const Second = Scalar(f32, .{ .T = 1 }, .{}); - - const dist = Meter.splat(120); - const time = Second.splat(4); - const vel = dist.div(time); - - try 
std.testing.expectEqual(30, vel.value()); -} - -test "Scalar Finer scales skip dim 0" { - const Dimless = Scalar(i128, .{}, .{}); - const KiloMetre = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - - const r = Dimless.splat(30); - const time = KiloMetre.splat(4); - const vel = r.mul(time); - - try std.testing.expectEqual(120, vel.value()); - try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L)); -} - -test "Scalar Conversion chain: km -> m -> cm" { - const KiloMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .k }); - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const CentiMeter = Scalar(i128, .{ .L = 1 }, .{ .L = .c }); - - const km = KiloMeter.splat(15); - const m = km.to(Meter); - const cm = m.to(CentiMeter); - - try std.testing.expectEqual(15_000, m.value()); - try std.testing.expectEqual(1_500_000, cm.value()); -} - -test "Scalar Conversion: hours -> minutes -> seconds" { - const Hour = Scalar(i128, .{ .T = 1 }, .{ .T = .hour }); - const Minute = Scalar(i128, .{ .T = 1 }, .{ .T = .min }); - const Second = Scalar(i128, .{ .T = 1 }, .{}); - - const h = Hour.splat(1.0); - const min = h.to(Minute); - const sec = min.to(Second); - - try std.testing.expectEqual(60, min.value()); - try std.testing.expectEqual(3600, sec.value()); -} - -test "Scalar Format Scalar" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - - const m = Meter.splat(1.23456); - const accel = MeterPerSecondSq.splat(9.81); - - var buf: [64]u8 = undefined; - var res = try std.fmt.bufPrint(&buf, "{d:.2}", .{m}); - try std.testing.expectEqualStrings("1.23m", res); - - res = try std.fmt.bufPrint(&buf, "{d}", .{accel}); - try std.testing.expectEqualStrings("9.81m.ns⁻²", res); -} - -test "Scalar Abs" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const m1 = Meter.splat(-50); - const m2 = m1.abs(); - - try std.testing.expectEqual(50, m2.value()); - - const m_float = Scalar(f32, .{ .L = 1 }, .{}); - const m3 = 
m_float.splat(-42.5); - try std.testing.expectEqual(42.5, m3.abs().value()); -} - -test "Scalar Pow" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(4); - - const area = d.pow(2); - try std.testing.expectEqual(16, area.value()); - - const volume = d.pow(3); - try std.testing.expectEqual(64, volume.value()); -} - -test "Scalar mul comptime_int" { - const Meter = Scalar(i128, .{ .L = 1 }, .{}); - const d = Meter.splat(7); - - const scaled = d.mul(3); - try std.testing.expectEqual(21, scaled.value()); -} - -test "Scalar add/sub bare number on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); - const a = DimLess.splat(10); - - const b = a.add(5); - try std.testing.expectEqual(15, b.value()); - - const c = a.sub(3); - try std.testing.expectEqual(7, c.value()); -} - -test "Scalar Imperial length scales" { - const Foot = Scalar(f64, .{ .L = 1 }, .{ .L = .ft }); - const Meter = Scalar(f64, .{ .L = 1 }, .{}); - const Inch = Scalar(f64, .{ .L = 1 }, .{ .L = .inch }); - - const one_ft = Foot.splat(1.0); - try std.testing.expectApproxEqAbs(0.3048, one_ft.to(Meter).value(), 1e-9); - - const twelve_in = Inch.splat(12.0); - try std.testing.expectApproxEqAbs(1.0, twelve_in.to(Foot).value(), 1e-9); -} - -test "Scalar Imperial mass scales" { - const Pound = Scalar(f64, .{ .M = 1 }, .{ .M = .lb }); - const Ounce = Scalar(f64, .{ .M = 1 }, .{ .M = .oz }); - - const two_lb = Pound.splat(2.0); - const eight_oz = Ounce.splat(8.0); - const total = two_lb.add(eight_oz).to(Pound); - try std.testing.expectApproxEqAbs(2.5, total.value(), 1e-6); -} - -test "Scalar comparisons with comptime_int on dimensionless scalar" { - const DimLess = Scalar(i128, .{}, .{}); - const x = DimLess.splat(42); - - try std.testing.expect(x.eq(42)); - try std.testing.expect(x.gt(10)); -} - -// Vector tests - -pub fn Vector(N: comptime_int, Q: type) type { - return Quantity(Q.ValueType, N, Q.dims.argsOpt(), Q.scales.argsOpt()); -} - -test "Vector initiate" { - const Meter = 
Vector(4, Scalar(f32, .{ .L = 1 }, .{})); - const m = Meter.splat(1); - - try std.testing.expect(m.data[0] == 1); -} - -test "Vector format" { - const MeterPerSecondSq = Scalar(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }); - const KgMeterPerSecond = Scalar(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }); - - const accel = MeterPerSecondSq.Vec3.splat(9.81); - const momentum = KgMeterPerSecond.Vec3{ .data = .{ 43, 0, 11 } }; - - var buf: [64]u8 = undefined; - var res = try std.fmt.bufPrint(&buf, "{d}", .{accel}); - try std.testing.expectEqualStrings("(9.81, 9.81, 9.81)m.ns⁻²", res); - - res = try std.fmt.bufPrint(&buf, "{d:.2}", .{momentum}); - try std.testing.expectEqualStrings("(43.00, 0.00, 11.00)m.kg.s⁻¹", res); -} - -test "Vector Vec3 Init and Basic Arithmetic" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; - - // Test zero, one, initDefault - const v_zero = Vec3M.zero; - try std.testing.expectEqual(0, v_zero.data[0]); - try std.testing.expectEqual(0, v_zero.data[1]); - try std.testing.expectEqual(0, v_zero.data[2]); - - const v_one = Vec3M.one; - try std.testing.expectEqual(1, v_one.data[0]); - try std.testing.expectEqual(1, v_one.data[1]); - try std.testing.expectEqual(1, v_one.data[2]); - - const v_def = Vec3M.splat(5); - try std.testing.expectEqual(5, v_def.data[0]); - try std.testing.expectEqual(5, v_def.data[1]); - try std.testing.expectEqual(5, v_def.data[2]); - - // Test add and sub - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 4, 6 } }; - - const added = v1.add(v2); - try std.testing.expectEqual(12, added.data[0]); - try std.testing.expectEqual(24, added.data[1]); - try std.testing.expectEqual(36, added.data[2]); - - const subbed = v1.sub(v2); - try std.testing.expectEqual(8, subbed.data[0]); - try std.testing.expectEqual(16, subbed.data[1]); - try std.testing.expectEqual(24, subbed.data[2]); - - // Test negate - const neg = v1.negate(); - try std.testing.expectEqual(-10, neg.data[0]); - try 
std.testing.expectEqual(-20, neg.data[1]); - try std.testing.expectEqual(-30, neg.data[2]); -} - -test "Vector Kinematics (Scalar Mul/Div)" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Second = Scalar(i32, .{ .T = 1 }, .{}); - const Vec3M = Meter.Vec3; - - const pos = Vec3M{ .data = .{ 100, 200, 300 } }; - const time = Second.splat(10); - - // Vector divided by scalar (Velocity = Position / Time) - const vel = pos.divScalar(time); - try std.testing.expectEqual(10, vel.data[0]); - try std.testing.expectEqual(20, vel.data[1]); - try std.testing.expectEqual(30, vel.data[2]); - try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); - try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); - - // Vector multiplied by scalar (Position = Velocity * Time) - const new_pos = vel.mulScalar(time); - try std.testing.expectEqual(100, new_pos.data[0]); - try std.testing.expectEqual(200, new_pos.data[1]); - try std.testing.expectEqual(300, new_pos.data[2]); - try std.testing.expectEqual(1, @TypeOf(new_pos).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); -} - -test "Vector Element-wise Math and Scaling" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const Vec3M = Meter.Vec3; - - const v1 = Vec3M{ .data = .{ 10, 20, 30 } }; - const v2 = Vec3M{ .data = .{ 2, 5, 10 } }; - - // Element-wise division - const div = v1.div(v2); - try std.testing.expectEqual(5, div.data[0]); - try std.testing.expectEqual(4, div.data[1]); - try std.testing.expectEqual(3, div.data[2]); - try std.testing.expectEqual(0, @TypeOf(div).dims.get(.L)); // M / M = Dimensionless -} - -test "Vector Conversions" { - const KiloMeter = Scalar(i32, .{ .L = 1 }, .{ .L = .k }); - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - - const v_km = KiloMeter.Vec3{ .data = .{ 1, 2, 3 } }; - const v_m = v_km.to(Meter); - - try std.testing.expectEqual(1000, v_m.data[0]); - try std.testing.expectEqual(2000, v_m.data[1]); - try std.testing.expectEqual(3000, v_m.data[2]); - - // 
Type checking the result - try std.testing.expectEqual(1, @TypeOf(v_m).dims.get(.L)); - try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m).scales.get(.L)); -} - -test "Vector Length" { - const MeterInt = Scalar(i32, .{ .L = 1 }, .{}); - const MeterFloat = Scalar(f32, .{ .L = 1 }, .{}); - - // Integer length - // 3-4-5 triangle on XY plane - const v_int = MeterInt.Vec3{ .data = .{ 3, 4, 0 } }; - try std.testing.expectEqual(25, v_int.lengthSqr()); - try std.testing.expectEqual(5, v_int.length()); - - // Float length - const v_float = MeterFloat.Vec3{ .data = .{ 3.0, 4.0, 0.0 } }; - try std.testing.expectApproxEqAbs(@as(f32, 25.0), v_float.lengthSqr(), 1e-4); - try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); -} - -test "Vector Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); - - const v1 = Meter.Vec3{ .data = .{ 1000.0, 500.0, 0.0 } }; - const v2 = KiloMeter.Vec3{ .data = .{ 1.0, 0.5, 0.0 } }; - const v3 = KiloMeter.Vec3{ .data = .{ 1.0, 0.6, 0.0 } }; - - // 1. Equality (Whole vector) - try std.testing.expect(v1.eqAll(v2)); - try std.testing.expect(v1.neAll(v3)); - - // 2. Element-wise Ordered Comparison - const higher = v3.gt(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(false, higher[0]); // 1km == 1000m - try std.testing.expectEqual(true, higher[1]); // 0.6km > 500m - try std.testing.expectEqual(false, higher[2]); // 0 == 0 - - // 3. Element-wise Equal Comparison - const equal = v3.eq(v1); // compares 1km, 0.6km, 0km vs 1000m, 500m, 0m - try std.testing.expectEqual(true, equal[0]); // 1km == 1000m - try std.testing.expectEqual(false, equal[1]); // 0.6km > 500m - try std.testing.expectEqual(true, equal[2]); // 0 == 0 - - // 3. 
Less than or equal - const low_eq = v1.lte(v3); - try std.testing.expect(low_eq[0] and low_eq[1] and low_eq[2]); -} - -test "Vector vs Scalar Comparisons" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const KiloMeter = Scalar(f32, .{ .L = 1 }, .{ .L = .k }); - - const positions = Meter.Vec3{ .data = .{ 500.0, 1200.0, 3000.0 } }; - const threshold = KiloMeter.splat(1); // 1km (1000m) - - // Check which axes exceed the 1km threshold - const exceeded = positions.gtScalar(threshold); - - try std.testing.expectEqual(false, exceeded[0]); // 500m > 1km is false - try std.testing.expectEqual(true, exceeded[1]); // 1200m > 1km is true - try std.testing.expectEqual(true, exceeded[2]); // 3000m > 1km is true - - // Check for equality (broadcasted) - const exact_match = positions.eqScalar(Meter.splat(500)); - try std.testing.expect(exact_match[0] == true); - try std.testing.expect(exact_match[1] == false); -} - -test "Vector Dot and Cross Products" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - const Newton = Scalar(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}); - - const pos = Meter.Vec3{ .data = .{ 10.0, 0.0, 0.0 } }; - const force = Newton.Vec3{ .data = .{ 5.0, 5.0, 0.0 } }; - - // 1. Dot Product (Work = F dot d) - const work = force.dot(pos); - try std.testing.expectEqual(50.0, work.value()); - // Dimensions should be M¹L²T⁻² (Energy/Joules) - try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); - try std.testing.expectEqual(2, @TypeOf(work).dims.get(.L)); - try std.testing.expectEqual(-2, @TypeOf(work).dims.get(.T)); - - // 2. 
Cross Product (Torque = r cross F) - const torque = pos.cross(force); - try std.testing.expectEqual(0.0, torque.data[0]); - try std.testing.expectEqual(0.0, torque.data[1]); - try std.testing.expectEqual(50.0, torque.data[2]); - // Torque dimensions are same as Energy but as a Vector - try std.testing.expectEqual(2, @TypeOf(torque).dims.get(.L)); -} - -test "Vector Abs, Pow, Sqrt and Product" { - const Meter = Scalar(f32, .{ .L = 1 }, .{}); - - const v1 = Meter.Vec3{ .data = .{ -2.0, 3.0, -4.0 } }; - - // 1. Abs - const v_abs = v1.abs(); - try std.testing.expectEqual(2.0, v_abs.data[0]); - try std.testing.expectEqual(4.0, v_abs.data[2]); - - // 2. Product (L1 * L1 * L1 = L3) - const vol = v_abs.product(); - try std.testing.expectEqual(24.0, vol.value()); - try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L)); - - // 3. Pow (Scalar exponent: (L1)^2 = L2) - const area_vec = v_abs.pow(2); - try std.testing.expectEqual(4.0, area_vec.data[0]); - try std.testing.expectEqual(16.0, area_vec.data[2]); - try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L)); - - // 4. 
Sqrt - const sqrted = area_vec.sqrt(); - try std.testing.expectEqual(2, sqrted.data[0]); - try std.testing.expectEqual(4, sqrted.data[2]); - try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L)); -} - -test "Vector mulScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 1, 2, 3 } }; - - const scaled = v.mulScalar(10); // comptime_int → dimensionless - try std.testing.expectEqual(10, scaled.data[0]); - try std.testing.expectEqual(20, scaled.data[1]); - try std.testing.expectEqual(30, scaled.data[2]); - // Dimensions unchanged: L¹ × dimensionless = L¹ - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); - try std.testing.expectEqual(0, @TypeOf(scaled).dims.get(.T)); -} - -test "Vector mulScalar comptime_float" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 1.0, 2.0, 4.0 } }; - - const scaled = v.mulScalar(0.5); // comptime_float → dimensionless - try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); - try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6); - try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6); - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); -} - -test "Vector mulScalar T (value type)" { - const MeterF = Scalar(f32, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 3.0, 6.0, 9.0 } }; - const factor: f32 = 2.0; - - const scaled = v.mulScalar(factor); - try std.testing.expectApproxEqAbs(6.0, scaled.data[0], 1e-6); - try std.testing.expectApproxEqAbs(12.0, scaled.data[1], 1e-6); - try std.testing.expectApproxEqAbs(18.0, scaled.data[2], 1e-6); - try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); -} - -test "Vector divScalar comptime_int" { - const Meter = Scalar(i32, .{ .L = 1 }, .{}); - const v = Meter.Vec3{ .data = .{ 10, 20, 30 } }; - - const halved = v.divScalar(2); // comptime_int → dimensionless divisor - try std.testing.expectEqual(5, halved.data[0]); - try std.testing.expectEqual(10, 
halved.data[1]); - try std.testing.expectEqual(15, halved.data[2]); - try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L)); -} - -test "Vector divScalar comptime_float" { - const MeterF = Scalar(f64, .{ .L = 1 }, .{}); - const v = MeterF.Vec3{ .data = .{ 9.0, 6.0, 3.0 } }; - - const r = v.divScalar(3.0); - try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); - try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9); - try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9); - try std.testing.expectEqual(1, @TypeOf(r).dims.get(.L)); -} - -test "Vector eqScalar / gtScalar with comptime_int on dimensionless vector" { - // Bare numbers are dimensionless, so comparisons only work when vector is dimensionless too. - const DimLess = Scalar(i32, .{}, .{}); - const v = DimLess.Vec3{ .data = .{ 1, 2, 3 } }; - - const eq_res = v.eqScalar(2); - try std.testing.expectEqual(false, eq_res[0]); - try std.testing.expectEqual(true, eq_res[1]); - try std.testing.expectEqual(false, eq_res[2]); - - const gt_res = v.gtScalar(1); - try std.testing.expectEqual(false, gt_res[0]); - try std.testing.expectEqual(true, gt_res[1]); - try std.testing.expectEqual(true, gt_res[2]); -} diff --git a/src/Scales.zig b/src/Scales.zig index 246565f..e7877e1 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const hlp = @import("helper.zig"); const Dimensions = @import("Dimensions.zig"); const Dimension = @import("Dimensions.zig").Dimension; @@ -56,35 +55,35 @@ pub const UnitScale = enum(isize) { // Undefined _, - pub inline fn str(self: @This()) []const u8 { + pub fn str(self: @This()) []const u8 { var buf: [16]u8 = undefined; return switch (self) { - inline .none => "", - inline .P, .T, .G, .M, .k, .h, .da, .d, .c, .m, .u, .n, .p, .f, .min, .hour, .year, .inch, .ft, .yd, .mi, .oz, .lb, .st => @tagName(self), + .none => "", + .P, .T, .G, .M, .k, .h, .da, .d, .c, .m, .u, .n, .p, .f, .min, .hour, .year, .inch, .ft, .yd, .mi, .oz, .lb, .st => 
@tagName(self), else => std.fmt.bufPrint(&buf, "[{d}]", .{@intFromEnum(self)}) catch "[]", // This cannot be inline because of non exhaustive enum, but that's ok, it is just str, not calculation }; } - pub inline fn getFactor(self: @This()) comptime_float { - return comptime switch (self) { + pub fn getFactor(self: @This()) comptime_float { + return switch (self) { // Standard SI Exponents - inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), + .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), // Time Factors - inline .min, .hour, .year => @floatFromInt(@intFromEnum(self)), + .min, .hour, .year => @floatFromInt(@intFromEnum(self)), // Imperial Length (metres) - inline .inch => 0.0254, - inline .ft => 0.3048, - inline .yd => 0.9144, - inline .mi => 1609.344, + .inch => 0.0254, + .ft => 0.3048, + .yd => 0.9144, + .mi => 1609.344, // Imperial Mass (grams — base unit for M is gram, i.e. .none = 1 g) - inline .oz => 28.3495231, - inline .lb => 453.59237, - inline .st => 6350.29318, + .oz => 28.3495231, + .lb => 453.59237, + .st => 6350.29318, - inline else => @floatFromInt(@intFromEnum(self)), + else => @floatFromInt(@intFromEnum(self)), }; } }; @@ -98,40 +97,46 @@ data: std.EnumArray(Dimension, UnitScale), /// Unspecified dimensions default to `.none` (factor 1). 
pub fn init(comptime init_val: ArgOpts) Self { comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) }; - inline for (std.meta.fields(@TypeOf(init_val))) |f| { - if (comptime hlp.isInt(@TypeOf(@field(init_val, f.name)))) + for (std.meta.fields(@TypeOf(init_val))) |f| { + if (comptime @typeInfo(@TypeOf(@field(init_val, f.name))) == .comptime_int) s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name))) else s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); } - return s; + return comptime s; } -pub fn initFill(comptime val: UnitScale) Self { - return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; +pub fn initFill(val: UnitScale) Self { + return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; } -pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { - return comptime self.data.get(key); +pub fn get(self: Self, key: Dimension) UnitScale { + return self.data.get(key); } -pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { - comptime self.data.set(key, val); +pub fn set(self: *Self, key: Dimension, val: UnitScale) void { + self.data.set(key, val); +} + +pub fn eql(self: Self, other: Self) bool { + for (self.data.values, other.data.values) |l, r| + if (l != r) return false; + return true; } pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); return args; } /// Compute the combined scale factor for a given dimension signature. /// Each dimension's prefix is raised to its exponent and multiplied together. 
-pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float { +pub fn getFactor(s: Self, d: Dimensions) comptime_float { var factor: f64 = 1.0; for (std.enums.values(Dimension)) |dim| { - const power = comptime d.get(dim); + const power = d.get(dim); if (power == 0) continue; const base = s.get(dim).getFactor(); @@ -145,5 +150,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float factor /= base; } } - return comptime factor; + return factor; } diff --git a/src/Tensor.zig b/src/Tensor.zig new file mode 100644 index 0000000..9b2a593 --- /dev/null +++ b/src/Tensor.zig @@ -0,0 +1,1395 @@ +const std = @import("std"); +const Scales = @import("Scales.zig"); +const UnitScale = Scales.UnitScale; +const Dimensions = @import("Dimensions.zig"); +const Dimension = Dimensions.Dimension; + +// ───────────────────────────────────────────────────────────────────────────── +// Comptime utilities +// ───────────────────────────────────────────────────────────────────────────── + +pub fn shapeTotal(shape: []const comptime_int) usize { + var t: comptime_int = 1; + for (shape) |s| t *= s; + return t; +} + +/// Check if two shapes are strictly identical. +pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool { + if (a.len != b.len) return false; + for (a, 0..) |v, i| + if (v != b[i]) return false; + return true; +} + +/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). +/// e.g. shape {3, 4} → strides {4, 1} +/// shape {2, 3, 4} → strides {12, 4, 1} +pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int { + var st: [shape.len]comptime_int = undefined; + if (shape.len == 0) return st; + st[shape.len - 1] = 1; + if (shape.len > 1) { + var i: comptime_int = shape.len - 1; + while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; + } + return st; +} + +/// Return a copy of `shape` with the element at `axis` removed. 
+pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int { + var out: [shape.len - 1]comptime_int = undefined; + var j: comptime_int = 0; + for (shape, 0..) |v, i| { + if (i != axis) { + out[j] = v; + j += 1; + } + } + return out; +} + +/// Concatenate two compile-time slices. +pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int { + var out: [a.len + b.len]comptime_int = undefined; + for (a, 0..) |v, i| out[i] = v; + for (b, 0..) |v, i| out[a.len + i] = v; + return out; +} + +/// Decode a flat row-major index into N-D coordinates. +/// Called only in comptime contexts (all arguments are comptime). +pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize { + var coords: [n]comptime_int = undefined; + var tmp = flat; + for (0..n) |i| { + coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; + tmp = if (strd[i] == 0) 0 else tmp % strd[i]; + } + return coords; +} + +/// Encode N-D coordinates into a flat row-major index. +/// Called only in comptime contexts. +pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize { + var flat: usize = 0; + for (0..n) |i| flat += coords[i] * strd[i]; + return flat; +} + +/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`. +/// `free` holds the remaining (non-contracted) coordinates in order. 
+pub fn insertAxis( + comptime n: usize, + comptime axis: usize, + comptime val: usize, + comptime free: []const usize, +) [n]usize { + var out: [n]usize = undefined; + var fi: usize = 0; + for (0..n) |i| { + if (i == axis) { + out[i] = val; + } else { + out[i] = free[fi]; + fi += 1; + } + } + return out; +} + +inline fn isInt(comptime T: type) bool { + return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; +} + +fn finerScales(comptime T1: type, comptime T2: type) Scales { + const d1: Dimensions = T1.dims; + const d2: Dimensions = T2.dims; + const s1: Scales = T1.scales; + const s2: Scales = T2.scales; + comptime var out = Scales.initFill(.none); + for (std.enums.values(Dimension)) |dim| { + const scale1 = comptime s1.get(dim); + const scale2 = comptime s2.get(dim); + out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) + .none + else if (comptime d1.get(dim) == 0) + scale2 + else if (comptime d2.get(dim) == 0) + scale1 + else if (comptime scale1.getFactor() > scale2.getFactor()) + scale2 + else + scale1); + } + return out; +} +// ───────────────────────────────────────────────────────────────────────────── +// File-scope RHS normalisation helpers +// ───────────────────────────────────────────────────────────────────────────── + +inline fn isTensor(comptime Rhs: type) bool { + return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); +} + +inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { + if (comptime isTensor(Rhs)) return Rhs; + return Tensor(T, .{}, .{}, &.{1}); +} + +/// Take the anyvalue coming from operation and if it is a Tensor, return it. +/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). 
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
    const Rhs = @TypeOf(r);
    // Tensors pass through untouched; only bare numbers need wrapping.
    if (comptime isTensor(Rhs)) return r;
    // Convert the bare number to the Tensor's element type, picking the cast
    // builtin from the (source kind × destination kind) pair.
    const scalar: T = switch (@typeInfo(Rhs)) {
        .comptime_int => switch (comptime @typeInfo(T)) {
            .float => @as(T, @floatFromInt(r)),
            else => @as(T, r),
        },
        .comptime_float => switch (comptime @typeInfo(T)) {
            .int => @as(T, @intFromFloat(r)),
            else => @as(T, r),
        },
        .int => switch (comptime @typeInfo(T)) {
            .float => @floatFromInt(r),
            else => @intCast(r),
        },
        .float => switch (comptime @typeInfo(T)) {
            .int => @intFromFloat(r),
            else => @floatCast(r),
        },
        else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
    };
    return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
}

/// Write `n` as Unicode superscript digits (with a superscript minus for
/// negatives); writes nothing for n == 0. Used for unit exponents in format().
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
    if (n == 0) return;
    var val = n;
    if (val < 0) {
        try writer.writeAll("\u{207B}");
        val = -val;
    }
    // 12 bytes comfortably holds the decimal digits of any i32 magnitude.
    var buf: [12]u8 = undefined;
    const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
    for (str) |c| {
        const s = switch (c) {
            '0' => "\u{2070}",
            '1' => "\u{00B9}",
            '2' => "\u{00B2}",
            '3' => "\u{00B3}",
            '4' => "\u{2074}",
            '5' => "\u{2075}",
            '6' => "\u{2076}",
            '7' => "\u{2077}",
            '8' => "\u{2078}",
            '9' => "\u{2079}",
            else => unreachable,
        };
        try writer.writeAll(s);
    }
}

/// A dimensioned, scaled, fixed-shape tensor stored as a single flat
/// `@Vector(total, T)` in row-major order.
/// `d_opt` carries the physical-dimension exponents, `s_opt` the per-dimension
/// unit scales, `shape_` the extents (rank >= 1, every extent >= 1).
pub fn Tensor(
    comptime T: type,
    comptime d_opt: Dimensions.ArgOpts,
    comptime s_opt: Scales.ArgOpts,
    comptime shape_: []const comptime_int,
) type {
    comptime {
        if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
        for (shape_) |s| {
            if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
        }
    }
    // Comptime normalisation (scales, strides, per-op type resolution) is
    // branch-heavy; raise the quota so large shapes still instantiate.
    @setEvalBranchQuota(100_000);

    const _total: usize = comptime shapeTotal(shape_);
    const _strides = comptime shapeStrides(shape_);
    const Vec = @Vector(_total, T);

    return struct {
        // Flat row-major element storage; the only runtime state.
        data: Vec,

        const Self = @This();

        pub const ValueType: type = T;
        pub const dims: Dimensions = Dimensions.init(d_opt);
        pub const scales: Scales = Scales.init(s_opt);
        pub const shape: []const comptime_int = shape_;
        pub const rank: comptime_int = shape_.len;
        pub const total: comptime_int = _total;
        pub const strides_arr: [shape_.len]comptime_int = _strides;
        // Marker decl checked by isTensor() to recognise Tensor instantiations.
        pub const ISTENSOR = true;

        /// Convert N-D coords (row-major) to flat index — fully comptime.
        /// Usage: Tensor.idx(.{row, col})
        pub inline fn idx(comptime coords: [rank]usize) usize {
            comptime {
                var flat: usize = 0;
                for (0..rank) |i| {
                    if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds");
                    flat += coords[i] * strides_arr[i];
                }
                return flat;
            }
        }

        /// Broadcast a single value across all elements.
        pub inline fn splat(v: T) Self {
            return .{ .data = @splat(v) };
        }

        pub const zero: Self = splat(0);
        pub const one: Self = splat(1);

        /// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping.
        pub inline fn asSlice(self: *Self) []T {
            return @as([*]T, @ptrCast(&self.data))[0..total];
        }

        // Normalised type of an arbitrary RHS (Tensor or bare number).
        inline fn RhsT(comptime Rhs: type) type {
            return RhsTensorType(T, Rhs);
        }
        // Normalised value of an arbitrary RHS (wraps bare numbers).
        inline fn rhs(r: anytype) RhsT(@TypeOf(r)) {
            return toRhsTensor(T, r);
        }

        // Expand a single-element RHS to this tensor's full vector width;
        // same-shape RHS data passes through untouched.
        inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
            return if (comptime RhsType.total == 1 and total > 1)
                @splat(r.data[0])
            else
                r.data;
        }

        /// Element-wise add. Dimensions must match; scales resolve to finer.
        /// RHS must have the same shape as self, or total == 1 (broadcast).
+ pub inline fn add(self: Self, r: anytype) Tensor( + T, + dims.argsOpt(), + finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + shape_, + ) { + const rhs_t = rhs(r); + const RhsType = @TypeOf(r); + if (comptime !dims.eql(RhsType.dims)) + @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); + if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); + if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too + return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; + + const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; + const rr: Vec = blk: { + const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); + break :blk broadcastToVec(RhsNorm, rn); + }; + return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; + } + + /// Element-wise sub. Dimensions must match; scales resolve to finer. + /// RHS must have the same shape as self, or total == 1 (broadcast). 
        pub inline fn sub(self: Self, r: anytype) Tensor(
            T,
            dims.argsOpt(),
            finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
            shape_,
        ) {
            const rhs_t = rhs(r);
            const RhsType = @TypeOf(rhs_t);
            if (comptime !dims.eql(RhsType.dims))
                @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
            if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
                @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
            // Fast path: single-element operands already on the same scale.
            if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
                return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };

            // Bring both sides to the finer common scale before subtracting.
            const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
            const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
            const rr: Vec = blk: {
                const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
                const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
                break :blk broadcastToVec(RhsNorm, rn);
            };
            // Integer subtracts saturate rather than trap on overflow.
            return .{ .data = if (comptime isInt(T)) l -| rr else l - rr };
        }

        /// Element-wise multiply. Dimension exponents summed.
        /// Shape {1} RHS is automatically broadcast across all elements.
        pub inline fn mul(self: Self, r: anytype) Tensor(
            T,
            dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
            finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
            shape_,
        ) {
            const rhs_q = rhs(r);
            const RhsType = @TypeOf(rhs_q);
            // No dimension check: multiplying different dimensions is legal;
            // the result type above sums the exponents.
            if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
                @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");

            // Normalise both operands to the finer common scale first.
            const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
            const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
            const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
            const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
            const rr: Vec = broadcastToVec(RhsNorm, rr_base);
            // Integer multiplies saturate rather than trap on overflow.
            return .{ .data = if (comptime isInt(T)) l *| rr else l * rr };
        }

        /// Element-wise divide. Dimension exponents subtracted.
        /// Shape {1} RHS is automatically broadcast across all elements.
        pub inline fn div(self: Self, r: anytype) Tensor(
            T,
            dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
            finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
            shape_,
        ) {
            const rhs_q = rhs(r);
            const RhsType = @TypeOf(rhs_q);
            if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
                @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");

            // Normalise both operands to the finer common scale first.
            const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
            const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
            const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
            const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
            const rr: Vec = broadcastToVec(RhsNorm, rr_base);
            if (comptime isInt(T)) {
                // Integer division truncates toward zero; divide-by-zero is a
                // runtime panic/UB exactly like plain Zig division.
                return .{ .data = @divTrunc(l, rr) };
            } else {
                return .{ .data = l / rr };
            }
        }

        /// Absolute value of every element.
        /// NOTE(review): for signed ints `@abs` yields the unsigned type and the
        /// `@bitCast` back means abs(minInt) wraps to minInt — confirm intended.
        pub inline fn abs(self: Self) Self {
            return .{ .data = @bitCast(@abs(self.data)) };
        }

        /// Raise every element to a comptime integer exponent.
        /// NOTE(review): for integer T with exp < 0 the reciprocal step below is
        /// skipped, so the result is base^|exp|, not a truncated reciprocal —
        /// consider a @compileError for that combination.
        pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor(
            T,
            dims.scale(exp).argsOpt(),
            scales.argsOpt(),
            shape_,
        ) {
            if (comptime exp == 0) return .{ .data = @splat(1) };
            if (comptime exp == 1) return self;

            var base = self.data;
            var result: Vec = @splat(1);
            comptime var e = @abs(exp);

            // $O(\log n)$ Exponentiation by squaring applied to the entire vector
            inline while (e > 0) {
                if (e % 2 == 1) {
                    result = if (comptime isInt(T)) result *| base else result * base;
                }
                e /= 2;
                if (e > 0) {
                    base = if (comptime isInt(T)) base *| base else base * base;
                }
            }
            // Negative exponents: take the reciprocal (float-only).
            if (comptime !isInt(T) and exp < 0) {
                result = @as(Vec, @splat(1)) / result;
            }
            return .{ .data = result };
        }

        /// Square root of every element. All dimension exponents must be even.
+ pub inline fn sqrt(self: Self) Tensor( + T, + dims.div(2).argsOpt(), + scales.argsOpt(), + shape_, + ) { + if (comptime !dims.isSquare()) + @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); + if (comptime @typeInfo(T) == .float) { + return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! + } else { + const arr: [total]T = self.data; // Add this! + var res_arr: [total]T = undefined; + const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); + for (0..total) |i| { + const v = arr[i]; + res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); + } + return .{ .data = res_arr }; + } + } + + /// Negate every element. + pub inline fn negate(self: Self) Self { + return .{ .data = -self.data }; + } + + /// Convert to a compatible Tensor type. + /// • Dimension mismatch → compile error. + /// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern). + /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. 
        pub inline fn to(
            self: Self,
            comptime Dest: type,
        ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
            // Re-instantiate the destination with *our* shape so a scalar
            // destination pattern still yields a same-shaped result.
            const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);

            if (comptime Self == ActualDest) return self;

            // Run validation checks FIRST before dealing with types
            if (comptime !dims.eql(ActualDest.dims))
                @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
            if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
                @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");

            // Net comptime conversion factor between the two scale sets.
            const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
            const DestT = ActualDest.ValueType;
            const DestVec = @Vector(total, DestT);

            if (comptime ratio == 1.0 and T == DestT)
                return .{ .data = self.data };

            // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
            if (comptime ratio == 1.0) {
                const T_info = @typeInfo(T);
                const Dest_info = @typeInfo(DestT);

                return .{
                    .data = if (comptime T_info == .int and Dest_info == .int)
                        @as(DestVec, @intCast(self.data))
                    else if (comptime T_info == .float and Dest_info == .float)
                        @as(DestVec, @floatCast(self.data))
                    else if (comptime T_info == .int and Dest_info == .float)
                        @as(DestVec, @floatFromInt(self.data))
                    else if (comptime T_info == .float and Dest_info == .int)
                        @as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
                    else
                        unreachable,
                };
            }

            // Same element type, ratio != 1: pure scale conversion.
            if (comptime T == DestT) {
                if (comptime @typeInfo(T) == .float)
                    return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };

                if (comptime ratio >= 1.0) {
                    // Upscale: saturating integer multiply.
                    const mult: T = comptime @intFromFloat(@round(ratio));
                    return .{ .data = self.data *| @as(Vec, @splat(mult)) };
                } else {
                    // Downscale: divide with half-divisor offset so the result
                    // rounds to nearest instead of truncating.
                    const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
                    const half: T = comptime @divTrunc(div_val, 2);

                    if (comptime @typeInfo(T).int.signedness == .unsigned) {
                        return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
                    } else {
                        // Vectorized branchless negative handling
                        const is_pos = self.data >= @as(Vec, @splat(0));
                        const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
                        return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
                    }
                }
            }

            // Cross-type fully vectorized casting with scales
            const FVec = @Vector(total, f64);
            const float_vec: FVec = switch (comptime @typeInfo(T)) {
                .float => @floatCast(self.data),
                .int => @floatFromInt(self.data),
                else => unreachable,
            };

            const scaled = float_vec * @as(FVec, @splat(ratio));

            return switch (comptime @typeInfo(DestT)) {
                .float => .{ .data = @floatCast(scaled) },
                .int => .{ .data = @intFromFloat(@round(scaled)) },
                else => unreachable,
            };
        }

        // Comparison result: a plain bool for single-element tensors, a
        // per-element bool array otherwise.
        const CmpResult = if (total == 1) bool else [total]bool;

        inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
            return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
        }

        /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
        inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
            const RhsType = @TypeOf(rhs_q);
            if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
                @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");

            const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
            const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
            const rr: Vec = blk: {
                const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
                const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
                break :blk broadcastToVec(RhsNorm, rn);
            };
            return .{ .l = l, .r = rr };
        }

        /// Element-wise equality after scale resolution.
        /// NOTE(review): the error message says "ne" — copy-paste; should read "eq".
        pub inline fn eq(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in ne.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l == p.r);
        }

        /// Element-wise inequality after scale resolution.
        pub inline fn ne(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in ne.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l != p.r);
        }

        /// Element-wise greater-than after scale resolution.
        pub inline fn gt(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in gt.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l > p.r);
        }

        /// Element-wise greater-or-equal after scale resolution.
        pub inline fn gte(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in gte.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l >= p.r);
        }

        /// Element-wise less-than after scale resolution.
        pub inline fn lt(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in lt.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l < p.r);
        }

        /// Element-wise less-or-equal after scale resolution.
        pub inline fn lte(self: Self, r: anytype) CmpResult {
            const rhs_q = rhs(r);
            if (comptime !dims.eql(@TypeOf(rhs_q).dims))
                @compileError("Dimension mismatch in lte.");
            const p = resolveScalePair(self, rhs_q);
            return cmpResult(p.l <= p.r);
        }

        /// True iff every element is equal after scale resolution.
        pub inline fn eqAll(self: Self, other: anytype) bool {
            if (comptime !dims.eql(@TypeOf(other).dims))
                @compileError("Dimension mismatch in eqAll.");
            const p = resolveScalePair(self, other);
            return @reduce(.And, p.l == p.r);
        }

        /// True iff any element differs after scale resolution.
        pub inline fn neAll(self: Self, other: anytype) bool {
            return !self.eqAll(other);
        }

        /// Tensor contraction: sum over `axis_a` of self × `axis_b` of other.
        /// Result shape = self.shape minus axis_a, then other.shape minus
        /// axis_b (a fully contracted result collapses to shape {1}).
        /// Dimension exponents add; scales resolve to the finer per dimension.
        /// NOTE(review): the fallback path passes *runtime* loop values into
        /// decodeFlatCoords/insertAxis/encodeFlatCoords, whose parameters are
        /// declared comptime — verify this path instantiates for rank > 2.
        pub inline fn contract(
            self: Self,
            other: anytype,
            comptime axis_a: usize,
            comptime axis_b: usize,
        ) blk: {
            const OT = @TypeOf(other);
            if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
            if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
            if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");

            const sa = shapeRemoveAxis(shape_, axis_a);
            const sb = shapeRemoveAxis(OT.shape, axis_b);
            const rs_raw = shapeCat(&sa, &sb);
            const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
            break :blk Tensor(
                T,
                dims.add(OT.dims).argsOpt(),
                finerScales(Self, OT).argsOpt(),
                rs,
            );
        } {
            const OT = @TypeOf(other);
            const k: usize = comptime shape_[axis_a]; // contraction dimension

            const sa = comptime shapeRemoveAxis(shape_, axis_a);
            const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
            const rs_raw = comptime shapeCat(&sa, &sb);
            const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;

            const ResultType = Tensor(
                T,
                dims.add(OT.dims).argsOpt(),
                finerScales(Self, OT).argsOpt(),
                rs,
            );

            // Normalise both operands to the finer common scale first.
            const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
            const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);

            const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
            const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;

            // FAST PATH: Dot Product
            if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
                if (comptime !isInt(T)) {
                    return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
                } else {
                    // For integers, we do a vectorized saturating multiply,
                    // then convert to an array to do a saturating sum
                    const mul_arr: [total]T = a_data *| b_data;
                    var acc: T = 0;
                    for (mul_arr) |val| acc +|= val;
                    return .{ .data = @splat(acc) };
                }
            }

            // --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
            const a_arr: [total]T = a_data;
            const b_arr: [OT.total]T = b_data;

            // FAST PATH: 2D Matrix Multiplication
            if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
                const rows = shape_[0];
                const cols = OT.shape[1];
                const inner = shape_[1];

                // Create a mutable array for the result, NOT a Tensor struct
                var res_arr: [ResultType.total]T = undefined;

                for (0..rows) |i| {
                    for (0..cols) |j| {
                        var acc: T = 0;
                        for (0..inner) |id| {
                            const a_flat = i * _strides[0] + id * _strides[1];
                            const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];

                            // Use a_arr and b_arr here
                            if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
                        }
                        // Write to the array
                        res_arr[i * cols + j] = acc;
                    }
                }
                // Return the initialized Tensor struct
                return .{ .data = res_arr };
            }

            // FALLBACK PATH
            const rs_raw_strides = comptime shapeStrides(&rs_raw);

            // Create a mutable array for the result
            var result_arr: [ResultType.total]T = undefined;

            for (0..ResultType.total) |res_flat| {
                const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);

                // Split the result coordinates back into the free coordinates
                // of each operand.
                var a_free: [sa.len]usize = undefined;
                for (0..sa.len) |i| a_free[i] = res_coords[i];
                var b_free: [sb.len]usize = undefined;
                for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];

                var acc: T = 0;
                for (0..k) |ki| {
                    const a_coords = insertAxis(rank, axis_a, ki, &a_free);
                    const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
                    const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
                    const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);

                    // Use a_arr and b_arr here
                    if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
                }
                // Write to the array
                result_arr[res_flat] = acc;
            }

            // Return the initialized Tensor struct
            return .{ .data = result_arr };
        }

        /// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
        /// Result dimensions are the sum of input dimensions.
        pub inline fn cross(self: Self, other: anytype) Tensor(
            T,
            dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
            finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
            &.{3},
        ) {
            const rhs_q = rhs(other);
            const RhsType = @TypeOf(rhs_q);

            if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
                @compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
            }

            // Bring both to the same scale (e.g., mm vs m)
            const p = self.resolveScalePair(rhs_q);
            const l = p.l;
            const r = p.r;

            var res: [3]T = undefined;
            if (comptime isInt(T)) {
                res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
                res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
                res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
            } else {
                res[0] = (l[1] * r[2]) - (l[2] * r[1]);
                res[1] = (l[2] * r[0]) - (l[0] * r[2]);
                res[2] = (l[0] * r[1]) - (l[1] * r[0]);
            }

            return .{ .data = res };
        }

        /// Sum of squared elements. Cheaper than length(); use for ordering.
+ pub inline fn lengthSqr(self: Self) T { + return @reduce(.Add, self.data * self.data); + } + + /// Euclidean length (L2 norm). + pub inline fn length(self: Self) T { + const sq = self.lengthSqr(); + if (comptime @typeInfo(T) == .int) { + const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); + return @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(sq))))); + } + return @sqrt(sq); + } + + /// Product of all elements. Result has shape {1}; dimension exponent × total. + pub inline fn product(self: Self) Tensor( + T, + dims.scale(@as(comptime_int, total)).argsOpt(), + scales.argsOpt(), + &.{1}, + ) { + return .{ .data = .{@reduce(.Mul, self.data)} }; + } + + pub fn formatNumber( + self: Self, + writer: *std.Io.Writer, + options: std.fmt.Number, + ) !void { + if (comptime total == 1) { + switch (@typeInfo(T)) { + .float, .comptime_float => try writer.printFloat(self.data[0], options), + .int, .comptime_int => try writer.printInt(self.data[0], 10, .lower, .{ + .width = options.width, + .alignment = options.alignment, + .fill = options.fill, + .precision = options.precision, + }), + else => unreachable, + } + } else { + try writer.writeAll("("); + const max_to_print = 6; + inline for (0..@min(total, max_to_print)) |i| { + if (i > 0) try writer.writeAll(", "); + switch (@typeInfo(T)) { + .float, .comptime_float => try writer.printFloat(self.data[i], options), + .int, .comptime_int => try writer.printInt(self.data[i], 10, .lower, .{ + .width = options.width, + .alignment = options.alignment, + .fill = options.fill, + .precision = options.precision, + }), + else => unreachable, + } + if (comptime i == max_to_print - 1 and total != max_to_print - 1) + try writer.writeAll(", ..."); + } + try writer.writeAll(")"); + } + + var first = true; + inline for (std.enums.values(Dimension)) |bu| { + const v = dims.get(bu); + if (comptime v == 0) continue; + if (!first) try writer.writeAll("."); + first = false; + + const uscale = scales.get(bu); + if (bu == .T and (uscale == 
.min or uscale == .hour or uscale == .year))
                    // Named time scales (min/hour/year) already embed the unit.
                    try writer.print("{s}", .{uscale.str()})
                else
                    try writer.print("{s}{s}", .{ uscale.str(), bu.unit() });

                if (v != 1) try printSuperscript(writer, v);
            }
        }
    };
}

// ═════════════════════════════════════════════════════════════════════════════
// Tests
// ─────────────────────────────────────────────────────────────────────────────

// ─── Scalar tests ─────────────────────────────────────────────────────────

// NOTE(review): test name "initiat" looks like a typo for "initiate".
test "Scalar initiat" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1});
    const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1});

    const distance = Meter.splat(10);
    const time = Second.splat(2);

    try std.testing.expectEqual(10, distance.data[0]);
    try std.testing.expectEqual(2, time.data[0]);
}

test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const m1000 = Meter.splat(1000);
    const km1 = KiloMeter.splat(1);
    const km2 = KiloMeter.splat(2);

    try std.testing.expect(m1000.eq(km1));
    try std.testing.expect(km1.eq(m1000));
    try std.testing.expect(km2.ne(m1000));

    try std.testing.expect(km2.gt(m1000));
    try std.testing.expect(km2.gt(km1));
    try std.testing.expect(km1.gte(m1000));
    try std.testing.expect(km2.gte(m1000));

    try std.testing.expect(m1000.lt(km2));
    try std.testing.expect(km1.lt(km2));
    try std.testing.expect(km1.lte(m1000));
    try std.testing.expect(m1000.lte(km2));
}

test "Scalar Add" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
    const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const distance = Meter.splat(10);
    const distance2 = Meter.splat(20);
    const added = distance.add(distance2);
    try std.testing.expectEqual(30, added.data[0]);
    try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));

    // Mixed scales resolve to the finer one (meters here).
    const distance3 = KiloMeter.splat(2);
    const added2 = distance.add(distance3);
    try std.testing.expectEqual(2010, added2.data[0]);

    const added3 = distance3.add(distance).to(KiloMeter);
    try std.testing.expectEqual(2, added3.data[0]);

    const distance4 = KiloMeter_f.splat(2);
    const added4 = distance4.add(distance).to(KiloMeter_f);
    try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
}

test "Scalar Sub" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const a = Meter.splat(500);
    const b = Meter.splat(200);
    const diff = a.sub(b);
    try std.testing.expectEqual(300, diff.data[0]);
    const diff2 = b.sub(a);
    try std.testing.expectEqual(-300, diff2.data[0]);

    // Result lands on the finer scale (meters).
    const km_f = KiloMeter_f.splat(2.5);
    const m_f = Meter.splat(500);
    const diff3 = km_f.sub(m_f);
    try std.testing.expectApproxEqAbs(2000, diff3.data[0], 1e-4);
}

test "Scalar MulBy" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});

    const d = Meter.splat(3);
    const t = Second.splat(4);
    const at = d.mul(t);
    try std.testing.expectEqual(12, at.data[0]);
    try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L));
    try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T));

    const d2 = Meter.splat(5);
    const area = d.mul(d2);
    try std.testing.expectEqual(15, area.data[0]);
    try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L));
}

test "Scalar MulBy with scale" {
    const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
    const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1});

    const dist = KiloMeter.splat(2.0);
    const mass = KiloGram.splat(3.0);
    const prod = dist.mul(mass);
    try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.L));
    try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.M));
}

test "Scalar MulBy with type change" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
    const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1});
    const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});
    const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});

    const d = Meter.splat(3);
    const t = Second.splat(4);

    try std.testing.expectEqual(12, d.mul(t).to(KmSec).data[0]);
    try std.testing.expectApproxEqAbs(12.0, d.mul(t).to(KmSec_f).data[0], 0.0001);
}

test "Scalar MulBy small" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1});
    const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});
    const d = Meter.splat(3);
    const t = Second.splat(4);
    try std.testing.expectEqual(12, d.mul(t).data[0]);
}

test "Scalar MulBy dimensionless" {
    const DimLess = Tensor(i128, .{}, .{}, &.{1});
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const d = Meter.splat(7);
    const scaled = d.mul(DimLess.splat(3));
    try std.testing.expectEqual(21, scaled.data[0]);
}

test "Scalar Sqrt" {
    const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1});
    const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1});

    var d = MeterSquare.splat(9);
    var scaled = d.sqrt();
    try std.testing.expectEqual(3, scaled.data[0]);
    try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));

    // Negative integer input clamps to 0.
    d = MeterSquare.splat(-5);
    scaled = d.sqrt();
    try std.testing.expectEqual(0, scaled.data[0]);

    const d2 = MeterSquare_f.splat(20);
    const scaled2 = d2.sqrt();
    try std.testing.expectApproxEqAbs(4.472135955, scaled2.data[0], 1e-4);
}

test "Scalar Chained: velocity and acceleration" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});

    const dist = Meter.splat(100);
    const t1 = Second.splat(5);
    const velocity = dist.div(t1);
    try std.testing.expectEqual(20, velocity.data[0]);

    const t2 = Second.splat(4);
    const accel = velocity.div(t2);
    try std.testing.expectEqual(5, accel.data[0]);
}

test "Scalar DivBy integer exact" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});

    const dist = Meter.splat(120);
    const time = Second.splat(4);
    const vel = dist.div(time);
    try std.testing.expectEqual(30, vel.data[0]);
}

test "Scalar Finer scales skip dim 0" {
    const Dimless = Tensor(i128, .{}, .{}, &.{1});
    const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});

    const r = Dimless.splat(30);
    const km = KiloMetre.splat(4);
    const vel = r.mul(km);
    try std.testing.expectEqual(120, vel.data[0]);
    try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L));
}

test "Scalar Conversion chain: km -> m -> cm" {
    const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1});

    const km = KiloMeter.splat(15);
    const m = km.to(Meter);
    const cm = m.to(CentiMeter);
    try std.testing.expectEqual(15_000, m.data[0]);
    try std.testing.expectEqual(1_500_000, cm.data[0]);
}

test "Scalar Conversion: hours -> minutes -> seconds" {
    const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1});
    const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1});
    const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1});

    const h = Hour.splat(1);
    const min = h.to(Minute);
    const sec = min.to(Second);
    try std.testing.expectEqual(60, min.data[0]);
    try std.testing.expectEqual(3600, sec.data[0]);
}

test "Scalar Format" {
    const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1});
    const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1});

    const m = Meter.splat(1.23456);
    const accel = MeterPerSecondSq.splat(9.81);

    var buf: [64]u8 = undefined;
    var res = try std.fmt.bufPrint(&buf, "{d:.2}", .{m});
    try std.testing.expectEqualStrings("1.23m", res);

    res = try std.fmt.bufPrint(&buf, "{d}", .{accel});
    try std.testing.expectEqualStrings("9.81m.ns⁻²", res);
}

test "Scalar Abs" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1});

    try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]);
    try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]);
}

test "Scalar Pow" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const d = Meter.splat(4);
    try std.testing.expectEqual(16, d.pow(2).data[0]);
    try std.testing.expectEqual(64, d.pow(3).data[0]);
}

test "Scalar mul comptime_int" {
    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
    const d = Meter.splat(7);
    try std.testing.expectEqual(21, d.mul(3).data[0]);
}

test "Scalar add/sub bare number on dimensionless scalar" {
    const DimLess = Tensor(i128, .{}, .{}, &.{1});
    const a = DimLess.splat(10);
    try std.testing.expectEqual(15, a.add(5).data[0]);
    try std.testing.expectEqual(7, a.sub(3).data[0]);
}

test "Scalar Imperial length scales" {
    const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1});
    const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1});
    const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1});

    try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9);
    try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9);
}

test "Scalar Imperial mass scales" {
    const Pound = Tensor(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1});
    const Ounce = Tensor(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1});

    const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound);
    try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6);
}

test "Scalar comparisons with comptime_int on dimensionless scalar" {
    const DimLess = Tensor(i128, .{}, .{}, &.{1});
    const x = DimLess.splat(42);
    try std.testing.expect(x.eq(42));
    try std.testing.expect(x.gt(10));
}

// ─── Vector / Tensor tests ──────────────────────────────────────────────── 

test "Vector initiate" {
    const Meter4 = Tensor(f32, .{ .L = 1 }, .{}, &.{4});
    const m = Meter4.splat(1);
    try std.testing.expect(m.data[0] == 1);
    try std.testing.expect(m.data[3] == 1);
}

test "Vector format" {
    const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3});
    const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3});

    const accel = MeterPerSecondSq.splat(9.81);
    const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } };

    var buf: [64]u8 = undefined;
    var res = try std.fmt.bufPrint(&buf, "{d}", .{accel});
    try std.testing.expectEqualStrings("(9.81, 9.81, 9.81)m.ns⁻²", res);

    res = try std.fmt.bufPrint(&buf, "{d:.2}", .{momentum});
    try std.testing.expectEqualStrings("(43.00, 0.00, 11.00)m.kg.s⁻¹", res);
}

test "Vector Vec3 Init and Basic Arithmetic" {
    const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});

    const v_zero = Meter3.zero;
    try std.testing.expectEqual(0, v_zero.data[0]);
    try std.testing.expectEqual(0, v_zero.data[2]);

    const v_one = Meter3.one;
    try std.testing.expectEqual(1, v_one.data[0]);

    const v_def = Meter3.splat(5);
    try std.testing.expectEqual(5, v_def.data[2]);

    const v1 = Meter3{ .data = .{ 10, 20, 30 } };
    const v2 = Meter3{ .data = .{ 2, 4, 6 } };

    const added = v1.add(v2);
    try std.testing.expectEqual(12, added.data[0]);
    try std.testing.expectEqual(24, added.data[1]);
    try std.testing.expectEqual(36, added.data[2]);

    const subbed = v1.sub(v2);
    try std.testing.expectEqual(8, subbed.data[0]);
    try std.testing.expectEqual(16, subbed.data[1]);
    try std.testing.expectEqual(24, subbed.data[2]);

    const neg = v1.negate();
    try std.testing.expectEqual(-10, neg.data[0]);
    try std.testing.expectEqual(-20, neg.data[1]);
    try std.testing.expectEqual(-30, neg.data[2]);
}

test "Vector Kinematics (scalar mul/div broadcast)" {
    const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
    const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1});
+ + const pos = Meter3{ .data = .{ 100, 200, 300 } }; + const time = Second1.splat(10); + + const vel = pos.div(time); + try std.testing.expectEqual(10, vel.data[0]); + try std.testing.expectEqual(20, vel.data[1]); + try std.testing.expectEqual(30, vel.data[2]); + try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L)); + try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T)); + + const new_pos = vel.mul(time); + try std.testing.expectEqual(100, new_pos.data[0]); + try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T)); +} + +test "Vector Element-wise Math and Scaling" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + + const v1 = Meter3{ .data = .{ 10, 20, 30 } }; + const v2 = Meter3{ .data = .{ 2, 5, 10 } }; + const dv = v1.div(v2); + try std.testing.expectEqual(5, dv.data[0]); + try std.testing.expectEqual(4, dv.data[1]); + try std.testing.expectEqual(3, dv.data[2]); + try std.testing.expectEqual(0, @TypeOf(dv).dims.get(.L)); +} + +test "Vector Conversions" { + const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + + const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } }; + const v_m = v_km.to(Meter3); + try std.testing.expectEqual(1000, v_m.data[0]); + try std.testing.expectEqual(2000, v_m.data[1]); + try std.testing.expectEqual(3000, v_m.data[2]); + try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m).scales.get(.L)); +} + +test "Vector Length" { + const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + + const v_int = MeterInt3{ .data = .{ 3, 4, 0 } }; + try std.testing.expectEqual(25, v_int.lengthSqr()); + try std.testing.expectEqual(5, v_int.length()); + + const v_float = MeterFloat3{ .data = .{ 3.0, 4.0, 0.0 } }; + try std.testing.expectApproxEqAbs(@as(f32, 25.0), v_float.lengthSqr(), 1e-4); + try std.testing.expectApproxEqAbs(@as(f32, 5.0), v_float.length(), 1e-4); +} + +test "Vector Comparisons" { + const 
Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + + const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; + const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; + const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } }; + + try std.testing.expect(v1.eqAll(v2)); + try std.testing.expect(v1.neAll(v3)); + + const higher = v3.gt(v1); + try std.testing.expectEqual(false, higher[0]); + try std.testing.expectEqual(true, higher[1]); + try std.testing.expectEqual(false, higher[2]); + + const equal = v3.eq(v1); + try std.testing.expectEqual(true, equal[0]); + try std.testing.expectEqual(false, equal[1]); + try std.testing.expectEqual(true, equal[2]); + + const low_eq = v1.lte(v3); + try std.testing.expect(low_eq[0] and low_eq[1] and low_eq[2]); +} + +test "Vector vs Scalar broadcast comparison" { + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); + + const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } }; + const threshold = KiloMeter1.splat(1); // 1 km = 1000 m + + const exceeded = positions.gt(threshold); + try std.testing.expectEqual(false, exceeded[0]); + try std.testing.expectEqual(true, exceeded[1]); + try std.testing.expectEqual(true, exceeded[2]); + + const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const exact = positions.eq(Meter1.splat(500)); + try std.testing.expect(exact[0] == true); + try std.testing.expect(exact[1] == false); +} + +test "Vector contract — dot product (rank-1 × rank-1)" { + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); + + const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; + const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; + + const work = force.contract(pos, 0, 0); + try std.testing.expectEqual(50.0, work.data[0]); + try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); + try std.testing.expectEqual(2, 
@TypeOf(work).dims.get(.L)); + try std.testing.expectEqual(-2, @TypeOf(work).dims.get(.T)); +} + +test "Vector contract — matrix multiply (rank-2 × rank-2)" { + const A = Tensor(f32, .{}, .{}, &.{ 2, 3 }); + const B = Tensor(f32, .{}, .{}, &.{ 3, 2 }); + + const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; + const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; + + const c = a.contract(b, 1, 0); + try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); + try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); + try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]); + try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]); +} + +test "Vector Abs, Pow, Sqrt and Product" { + const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + + const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; + const v_abs = v1.abs(); + try std.testing.expectEqual(2.0, v_abs.data[0]); + try std.testing.expectEqual(4.0, v_abs.data[2]); + + const vol = v_abs.product(); + try std.testing.expectEqual(24.0, vol.data[0]); + try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L)); + + const area_vec = v_abs.pow(2); + try std.testing.expectEqual(4.0, area_vec.data[0]); + try std.testing.expectEqual(16.0, area_vec.data[2]); + try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L)); + + const sqrted = area_vec.sqrt(); + try std.testing.expectEqual(2, sqrted.data[0]); + try std.testing.expectEqual(4, sqrted.data[2]); + try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L)); +} + +test "Vector mul comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 1, 2, 3 } }; + const scaled = v.mul(10); + try std.testing.expectEqual(10, scaled.data[0]); + try std.testing.expectEqual(20, scaled.data[1]); + try std.testing.expectEqual(30, scaled.data[2]); + try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); +} 
+ +test "Vector mul comptime_float broadcast" { + const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; + const scaled = v.mul(0.5); + try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); + try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6); + try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6); + try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L)); +} + +test "Vector div comptime_int broadcast" { + const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const v = Meter3{ .data = .{ 10, 20, 30 } }; + const halved = v.div(2); + try std.testing.expectEqual(5, halved.data[0]); + try std.testing.expectEqual(10, halved.data[1]); + try std.testing.expectEqual(15, halved.data[2]); + try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L)); +} + +test "Vector div comptime_float broadcast" { + const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3}); + const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; + const r = v.div(3.0); + try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); + try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9); + try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9); +} + +test "Vector eq broadcast on dimensionless" { + const DimLess3 = Tensor(i32, .{}, .{}, &.{3}); + const v = DimLess3{ .data = .{ 1, 2, 3 } }; + + const eq_res = v.eq(2); + try std.testing.expectEqual(false, eq_res[0]); + try std.testing.expectEqual(true, eq_res[1]); + try std.testing.expectEqual(false, eq_res[2]); + + const gt_res = v.gt(1); + try std.testing.expectEqual(false, gt_res[0]); + try std.testing.expectEqual(true, gt_res[1]); + try std.testing.expectEqual(true, gt_res[2]); +} + +test "Tensor idx helper and matrix access" { + const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 }); + var m: Mat3x3 = Mat3x3.zero; + m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0; + m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0; + m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0; + + try std.testing.expectEqual(1.0, m.data[0]); + try 
std.testing.expectEqual(2.0, m.data[4]); + try std.testing.expectEqual(3.0, m.data[8]); + try std.testing.expectEqual(0.0, m.data[1]); +} + +test "Tensor strides_arr correctness" { + const T1 = Tensor(f32, .{}, .{}, &.{3}); + const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 }); + const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 }); + + try std.testing.expectEqual(1, T1.strides_arr[0]); + try std.testing.expectEqual(4, T2.strides_arr[0]); + try std.testing.expectEqual(1, T2.strides_arr[1]); + try std.testing.expectEqual(12, T3.strides_arr[0]); + try std.testing.expectEqual(4, T3.strides_arr[1]); + try std.testing.expectEqual(1, T3.strides_arr[2]); +} diff --git a/src/benchmark.zig b/src/benchmark.zig index 7da3e6f..c6fb592 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -1,7 +1,6 @@ const std = @import("std"); const Io = std.Io; -const Scalar = @import("Quantity.zig").Scalar; -const Vector = @import("Quantity.zig").Vector; +const Tensor = @import("Tensor.zig").Tensor; var io: Io = undefined; pub fn main(init: std.process.Init) !void { @@ -21,15 +20,17 @@ pub fn main(init: std.process.Init) !void { // try stdout_writer.flush(); // try vectorSIMDvsNative(i128, &stdout_writer.interface); // try stdout_writer.flush(); - - try bench_Scalar(&stdout_writer.interface); - try stdout_writer.flush(); + // + // try bench_Scalar(&stdout_writer.interface); + // try stdout_writer.flush(); try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); - try bench_crossTypeVsNative(&stdout_writer.interface); + // try bench_crossTypeVsNative(&stdout_writer.interface); try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); + try bench_HighDimTensor(&stdout_writer.interface); + try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -97,9 +98,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { comptime var tidx: usize = 0; inline for (Types, TNames) |T, tname| { - const M = Scalar(T, .{ .L = 1 }, .{}); - const KM = Scalar(T, .{ 
.L = 1 }, .{ .L = .k }); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); + const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); inline for (Ops, 0..) |op_name, oidx| { var samples: [SAMPLES]f64 = undefined; @@ -170,7 +171,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { fn bench_vsNative(writer: *std.Io.Writer) !void { const ITERS: usize = 100_000; - const SAMPLES: usize = 5; + const SAMPLES: usize = 100; const getValT = struct { fn f(comptime TT: type, i: usize) TT { @@ -179,8 +180,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { } }.f; - const Types = .{ i32, i64, i128, f32, f64 }; - const TNames = .{ "i32", "i64", "i128", "f32", "f64" }; + const Types = .{ f64, i64, i128, f32, f64 }; + const TNames = .{ "f64", "i64", "i128", "f32", "f64" }; // Expanded Ops to match bench_Scalar const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" }; @@ -188,35 +189,34 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { \\ \\ Scalar vs Native Overhead Analysis \\ - \\┌───────────┬──────┬───────────┬───────────┬───────────┐ - \\│ Operation │ Type │ Native │ Scalar │ Slowdown │ - \\├───────────┼──────┼───────────┼───────────┼───────────┤ + \\┌───────────┬──────┬───────────┬───────────┬───────────┬───────────────────────┐ + \\│ Operation │ Type │ Native │ @Vector │ Tensor{{1}} │ Slowdown Nat | Vec │ + \\├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤ \\ , .{}); inline for (Ops, 0..) |op_name, j| { inline for (Types, 0..) |T, tidx| { var native_total_ns: f64 = 0; - var quantity_total_ns: f64 = 0; + var vector_total_ns: f64 = 0; + var tensor_total_ns: f64 = 0; - const M = Scalar(T, .{ .L = 1 }, .{}); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{}, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { // --- 1. 
Benchmark Native --- const n_start = getTime(); - for (0..ITERS) |i| { - const a = getValT(T, i); - const b = getValT(T, 2); - + const a = getValT(T, 10); + const b = getValT(T, 2); + for (0..ITERS) |_| { // Native logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) - a + b + if (comptime @typeInfo(T) == .int) a +| b else a + b else if (comptime std.mem.eql(u8, op_name, "sub")) - a - b + if (comptime @typeInfo(T) == .int) a -| b else a - b else if (comptime std.mem.eql(u8, op_name, "mul")) - a * b + if (comptime @typeInfo(T) == .int) a *| b else a * b else if (comptime std.mem.eql(u8, op_name, "div")) if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b else if (comptime std.mem.eql(u8, op_name, "abs")) @@ -231,12 +231,36 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { const n_end = getTime(); native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds())); + const v_start = getTime(); + const va = getValT(T, 10); + const vb = getValT(T, 2); + for (0..ITERS) |_| { + // Native logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + if (comptime @typeInfo(T) == .int) va +| vb else va + vb + else if (comptime std.mem.eql(u8, op_name, "sub")) + if (comptime @typeInfo(T) == .int) va -| vb else va - vb + else if (comptime std.mem.eql(u8, op_name, "mul")) + if (comptime @typeInfo(T) == .int) va *| vb else va * vb + else if (comptime std.mem.eql(u8, op_name, "div")) + if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb + else if (comptime std.mem.eql(u8, op_name, "abs")) + if (comptime @typeInfo(T) == .int) @abs(va) else @as(T, @abs(va)) + else if (comptime std.mem.eql(u8, op_name, "eq")) + va == vb + else if (comptime std.mem.eql(u8, op_name, "gt")) + va > vb + else + unreachable; + } + const v_end = getTime(); + vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds())); + // --- 2. 
Benchmark Scalar --- const q_start = getTime(); - for (0..ITERS) |i| { - const qa = M.splat(getValT(T, i)); - const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2)); - + const qa = M.splat(getValT(T, 10)); + const qb = M.splat(getValT(T, 2)); + for (0..ITERS) |_| { // Scalar logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) qa.add(qb) @@ -256,22 +280,24 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { unreachable; } const q_end = getTime(); - quantity_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); + tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); } }); const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const avg_q = (quantity_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const slowdown = avg_q / avg_n; + const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const avg_t = (tensor_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const slowdown_nt = avg_t / avg_n; + const slowdown_vt = avg_t / avg_v; - try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x │\n", .{ - op_name, TNames[tidx], avg_n, avg_q, slowdown, + try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x {d:>8.2}x │\n", .{ + op_name, TNames[tidx], avg_n, avg_v, avg_t, slowdown_nt, slowdown_vt, }); } - if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┤\n", .{}); + if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤\n", .{}); } - try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┘\n", .{}); + try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┴───────────────────────┘\n", .{}); } fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void { @@ -321,9 +347,9 @@ fn bench_crossTypeVsNative(writer: 
*std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M1 = Scalar(T1, .{ .L = 1 }, .{}); - const M2 = Scalar(T2, .{ .L = 1 }, .{}); - const S2 = Scalar(T2, .{ .T = 1 }, .{}); + const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1}); + const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1}); + const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { @@ -429,9 +455,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void { try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname }); inline for (Lengths) |len| { - const Q_base = Scalar(T, .{ .L = 1 }, .{}); - const Q_time = Scalar(T, .{ .T = 1 }, .{}); - const V = Vector(len, Q_base); + const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1}); + const V = Tensor(T, .{ .L = 1 }, .{}, &.{len}); // cross product is only defined for len == 3 const is_cross = comptime std.mem.eql(u8, op_name, "cross"); @@ -455,10 +480,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void { _ = v1.div(V.splat(getVal(T, i +% 2, 63))); } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) { const s_val = Q_time.splat(getVal(T, i +% 2, 63)); - _ = v1.mulScalar(s_val); + _ = v1.mul(s_val); } else if (comptime std.mem.eql(u8, op_name, "dot")) { const v2 = V.splat(getVal(T, i +% 5, 63)); - _ = v1.dot(v2); + _ = v1.contract(v2, 0, 0); } else if (comptime std.mem.eql(u8, op_name, "cross")) { // len == 3 guaranteed by the guard above const v2 = V.splat(getVal(T, i +% 5, 63)); @@ -490,6 +515,102 @@ fn bench_Vector(writer: *std.Io.Writer) !void { try writer.print("└──────────────────┴──────┴─────────┴─────────┴─────────┴─────────┴─────────┘\n", .{}); } +fn bench_HighDimTensor(writer: *std.Io.Writer) !void { + const ITERS: usize = 5_000; + const SAMPLES: usize = 5; + + const getVal = struct { + fn f(comptime TT: type, i: usize, comptime mask: u7) TT { + const v: u8 = @as(u8, @truncate(i & @as(usize, mask))) + 1; + return if (comptime @typeInfo(TT) == .float) @floatFromInt(v) else @intCast(v); + } 
+ }.f; + + const computeStats = struct { + fn f(samples: []f64, iters: usize) f64 { + std.mem.sort(f64, samples, {}, std.sort.asc(f64)); + const mid = samples.len / 2; + const median_ns = if (samples.len % 2 == 0) + (samples[mid - 1] + samples[mid]) / 2.0 + else + samples[mid]; + return median_ns / @as(f64, @floatFromInt(iters)); + } + }.f; + + try writer.print( + \\ + \\ High Dimension Tensor benchmark — {d} iterations, {d} samples/cell + \\ (Results in ns/op) + \\ + \\┌─────────────────┬──────┬──────────────┬──────────────┬──────────────┬──────────────┐ + \\│ Operation │ Type │ 2x2x2 │ 3x3x3 │ 4x4x4 │ 10x10x10x10 │ + \\├─────────────────┼──────┼──────────────┼──────────────┼──────────────┼──────────────┤ + \\ + , .{ ITERS, SAMPLES }); + + const Types = .{ i32, i64, f32, f64 }; + const TNames = .{ "i32", "i64", "f32", "f64" }; + + // Testing multiple structural bounds + const Shapes = .{ + &.{ 2, 2, 2 }, + &.{ 3, 3, 3 }, + &.{ 4, 4, 4 }, + &.{ 10, 10, 10, 10 }, + }; + + const Ops = .{ "add", "sub", "mulElem", "mulScalar", "abs" }; + + inline for (Ops, 0..) 
|op_name, o_idx| { + inline for (Types, TNames) |T, tname| { + try writer.print("│ {s:<15} │ {s:<4} │", .{ op_name, tname }); + + inline for (Shapes) |shape| { + const V = Tensor(T, .{ .L = 1 }, .{}, shape); + const Q = Tensor(T, .{ .T = 1 }, .{}, &.{1}); // For scalar broadcasting operations + + var samples: [SAMPLES]f64 = undefined; + + for (0..SAMPLES) |s_idx| { + const t_start = getTime(); + + for (0..ITERS) |i| { + std.mem.doNotOptimizeAway({ + const t1 = V.splat(getVal(T, i, 63)); + + _ = if (comptime std.mem.eql(u8, op_name, "add")) + t1.add(V.splat(getVal(T, i +% 7, 63))) + else if (comptime std.mem.eql(u8, op_name, "sub")) + t1.sub(V.splat(getVal(T, i +% 3, 63))) + else if (comptime std.mem.eql(u8, op_name, "mulElem")) + t1.mul(V.splat(getVal(T, i +% 5, 63))) + else if (comptime std.mem.eql(u8, op_name, "mulScalar")) + t1.mul(Q.splat(getVal(T, i +% 2, 63))) + else if (comptime std.mem.eql(u8, op_name, "abs")) + t1.abs() + else + unreachable; + }); + } + + const t_end = getTime(); + samples[s_idx] = @as(f64, @floatFromInt(t_start.durationTo(t_end).toNanoseconds())); + } + + const median_ns_per_op = computeStats(&samples, ITERS); + try writer.print(" {d:>12.1} │", .{median_ns_per_op}); + } + try writer.print("\n", .{}); + } + + if (o_idx < Ops.len - 1) { + try writer.print("├─────────────────┼──────┼──────────────┼──────────────┼──────────────┼──────────────┤\n", .{}); + } + } + try writer.print("└─────────────────┴──────┴──────────────┴──────────────┴──────────────┴──────────────┘\n", .{}); +} + fn vectorSIMDvsNative(comptime T: type, writer: *std.Io.Writer) !void { const iterations: u64 = 10_000; const lens = [_]u32{ 1, 2, 3, 4, 5, 10, 100, 1_000, 10_000 }; diff --git a/src/helper.zig b/src/helper.zig deleted file mode 100644 index fb104e1..0000000 --- a/src/helper.zig +++ /dev/null @@ -1,97 +0,0 @@ -const std = @import("std"); - -pub fn isInt(comptime T: type) bool { - return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; -} - -pub fn 
printSuperscript(writer: *std.Io.Writer, n: i32) !void { - if (n == 0) return; - var val = n; - if (val < 0) { - try writer.writeAll("\u{207B}"); - val = -val; - } - var buf: [12]u8 = undefined; - const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return; - for (str) |c| { - const s = switch (c) { - '0' => "\u{2070}", - '1' => "\u{00B9}", - '2' => "\u{00B2}", - '3' => "\u{00B3}", - '4' => "\u{2074}", - '5' => "\u{2075}", - '6' => "\u{2076}", - '7' => "\u{2077}", - '8' => "\u{2078}", - '9' => "\u{2079}", - else => unreachable, - }; - try writer.writeAll(s); - } -} - -const Scales = @import("Scales.zig"); -const Dimensions = @import("Dimensions.zig"); -const Dimension = @import("Dimensions.zig").Dimension; - -pub fn finerScales(comptime T1: type, comptime T2: type) Scales { - const d1: Dimensions = T1.dims; - const d2: Dimensions = T2.dims; - const s1: Scales = T1.scales; - const s2: Scales = T2.scales; - comptime var out = Scales.initFill(.none); - inline for (std.enums.values(Dimension)) |dim| { - const scale1 = comptime s1.get(dim); - const scale2 = comptime s2.get(dim); - out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) - .none - else if (comptime d1.get(dim) == 0) - scale2 - else if (comptime d2.get(dim) == 0) - scale1 - else if (comptime scale1.getFactor() > scale2.getFactor()) - scale2 - else - scale1); - } - comptime return out; -} - -// --------------------------------------------------------------------------- -// RHS normalisation helpers -// --------------------------------------------------------------------------- - -const Quantity = @import("Quantity.zig").Quantity; - -/// Returns true if `T` is a `Scalar_` type (has `dims`, `scales`, and `value`). -pub fn isScalarType(comptime T: type) bool { - return @typeInfo(T) == .@"struct" and - @hasDecl(T, "ISQUANTITY") and - @field(T, "ISQUANTITY"); -} - -/// Resolve the Scalar type that `rhs` will be treated as. 
-/// -/// Accepted rhs types: -/// - Any `Scalar_` type → returned as-is -/// - `comptime_int` / `comptime_float` → dimensionless `Scalar_(BaseT, {}, {})` -/// - `BaseT` (the scalar's value type) → dimensionless `Scalar_(BaseT, {}, {})` -/// -/// Everything else is a compile error, including other int/float types. -pub fn rhsQuantityType(comptime ValueType: type, N: usize, comptime RhsT: type) type { - if (comptime isScalarType(RhsT)) return RhsT; - if (comptime RhsT == comptime_int or RhsT == comptime_float or RhsT == ValueType) - return Quantity(ValueType, N, .{}, .{}); - @compileError( - "rhs must be a Scalar, " ++ @typeName(ValueType) ++ - ", comptime_int, or comptime_float; got " ++ @typeName(RhsT), - ); -} - -/// Convert `rhs` to its normalised Scalar form (see `rhsScalarType`). -pub inline fn toRhsQuantity(comptime BaseT: type, N: usize, rhs: anytype) rhsQuantityType(BaseT, N, @TypeOf(rhs)) { - if (comptime isScalarType(@TypeOf(rhs))) return rhs; - const DimLess = Quantity(BaseT, N, .{}, .{}); - return DimLess{ .data = @splat(@as(BaseT, rhs)) }; -} diff --git a/src/main.zig b/src/main.zig index 4d61ee6..33ba626 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,15 +1,13 @@ const std = @import("std"); -pub const Vector = @import("Quantity.zig").Vector; -pub const Scalar = @import("Quantity.zig").Scalar; +pub const Tensor = @import("Tensor.zig").Tensor; pub const Dimensions = @import("Dimensions.zig"); pub const Scales = @import("Scales.zig"); pub const Base = @import("Base.zig"); test { - _ = @import("Quantity.zig"); + _ = @import("Tensor.zig"); _ = @import("Dimensions.zig"); _ = @import("Scales.zig"); _ = @import("Base.zig"); - _ = @import("helper.zig"); }