Compare commits

...

4 Commits

Author SHA1 Message Date
AdrienBouvais
44aaa8a8b2 Replaced every `inline for (0..total)` with either a builtin or a plain (non-inline) for loop
This is to prevent giant binaries for Tensors with a lot of Scalars
2026-04-27 16:11:46 +02:00
AdrienBouvais
cd954b379b Added cross to Tensor + fix benchmark 2026-04-27 15:13:15 +02:00
AdrienBouvais
16d25e7e7e Added shape comptime check for Tensor add/sub/div/mul 2026-04-27 14:45:41 +02:00
AdrienBouvais
f37a196b15 Fixed new Tensor to be everything (Scalar, Vector, Matrix and above) 2026-04-27 09:11:24 +02:00
7 changed files with 592 additions and 639 deletions

View File

@ -3,34 +3,39 @@ const std = @import("std");
// Adjust these imports to match your actual file names // Adjust these imports to match your actual file names
const Dimensions = @import("Dimensions.zig"); const Dimensions = @import("Dimensions.zig");
const Scales = @import("Scales.zig"); const Scales = @import("Scales.zig");
const Scalar = @import("Quantity.zig").Scalar; const Tensor = @import("Tensor.zig").Tensor;
fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type { fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type {
return struct { return struct {
const dims = Dimensions.init(d); pub const dims = Dimensions.init(d);
const scales = Scales.init(s); pub const scales = Scales.init(s);
/// Instantiates the constant into a specific numeric type. /// Instantiates the constant into a specific numeric type.
pub fn Of(comptime T: type) Scalar(T, d, s) { pub fn Of(comptime T: type) Tensor(T, d, s, &.{1}) {
return .{ .data = @splat(@as(T, @floatCast(val))) }; const casted_val: T = switch (@typeInfo(T)) {
.float => @floatCast(val),
.int => @intFromFloat(val),
else => @compileError("Unsupported type for PhysicalConstant"),
};
return Tensor(T, d, s, &.{1}).splat(casted_val);
} }
}; };
} }
fn BaseScalar(comptime d: Dimensions.ArgOpts) type { fn BaseScalar(comptime d: Dimensions.ArgOpts) type {
return struct { return struct {
const dims = Dimensions.init(d); pub const dims = Dimensions.init(d);
/// Creates a Scalar of this dimension using default scales. /// Creates a Scalar of this dimension using default scales.
/// Example: const V = Quantities.Velocity.Base(f32); /// Example: const V = Quantities.Velocity.Of(f32);
pub fn Of(comptime T: type) type { pub fn Of(comptime T: type) type {
return Scalar(T, d, .{}); return Tensor(T, d, .{}, &.{1});
} }
/// Creates a Scalar of this dimension using custom scales. /// Creates a Scalar of this dimension using custom scales.
/// Example: const Kmh = Quantities.Velocity.Scaled(f32, Scales.init(.{ .L = .k, .T = .hour })); /// Example: const Kmh = Quantities.Velocity.Scaled(f32, .{ .L = .k, .T = .hour });
pub fn Scaled(comptime T: type, comptime s: Scales.ArgOpts) type { pub fn Scaled(comptime T: type, comptime s: Scales.ArgOpts) type {
return Scalar(T, d, s); return Tensor(T, d, s, &.{1});
} }
}; };
} }
@ -107,7 +112,7 @@ pub const ElectricCapacitance = BaseScalar(.{ .T = 4, .L = -2, .M = -1, .I = 2 }
pub const ElectricImpedance = ElectricResistance; pub const ElectricImpedance = ElectricResistance;
pub const MagneticFlux = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .I = -1 }); pub const MagneticFlux = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .I = -1 });
pub const MagneticDensity = BaseScalar(.{ .M = 1, .T = -2, .I = -1 }); pub const MagneticDensity = BaseScalar(.{ .M = 1, .T = -2, .I = -1 });
pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); // Fixed typo from MagneticStrengh pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 });
pub const MagneticMoment = BaseScalar(.{ .L = 2, .I = 1 }); pub const MagneticMoment = BaseScalar(.{ .L = 2, .I = 1 });
// ========================================== // ==========================================
@ -140,7 +145,7 @@ pub const ThermalHeat = Energy;
pub const ThermalWork = Energy; pub const ThermalWork = Energy;
pub const ThermalCapacity = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 }); pub const ThermalCapacity = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 });
pub const ThermalCapacityPerMass = BaseScalar(.{ .L = 2, .T = -2, .Tr = -1 }); pub const ThermalCapacityPerMass = BaseScalar(.{ .L = 2, .T = -2, .Tr = -1 });
pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); // Fixed typo from ThermalluxDensity pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 });
pub const ThermalConductance = BaseScalar(.{ .M = 1, .L = 2, .T = -3, .Tr = -1 }); pub const ThermalConductance = BaseScalar(.{ .M = 1, .L = 2, .T = -3, .Tr = -1 });
pub const ThermalConductivity = BaseScalar(.{ .M = 1, .L = 1, .T = -3, .Tr = -1 }); pub const ThermalConductivity = BaseScalar(.{ .M = 1, .L = 1, .T = -3, .Tr = -1 });
pub const ThermalResistance = BaseScalar(.{ .M = -1, .L = -2, .T = 3, .Tr = 1 }); pub const ThermalResistance = BaseScalar(.{ .M = -1, .L = -2, .T = 3, .Tr = 1 });
@ -152,20 +157,24 @@ pub const ThermalEntropy = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 });
// ========================================== // ==========================================
pub const Frequency = BaseScalar(.{ .T = -1 }); pub const Frequency = BaseScalar(.{ .T = -1 });
pub const Viscosity = BaseScalar(.{ .M = 1, .L = -1, .T = -1 }); pub const Viscosity = BaseScalar(.{ .M = 1, .L = -1, .T = -1 });
pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); // Corrected from MT-2a pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 });
// ==========================================
// Tests
// ==========================================
test "BaseQuantities - Core dimensions instantiation" { test "BaseQuantities - Core dimensions instantiation" {
// Basic types via generic wrappers // Basic types via generic wrappers
const M = Meter.Of(f32); const M = Meter.Of(f32);
const distance = M.splat(100); const distance = M.splat(100);
try std.testing.expectEqual(100.0, distance.value()); try std.testing.expectEqual(100.0, distance.data[0]);
try std.testing.expectEqual(1, M.dims.get(.L)); try std.testing.expectEqual(1, M.dims.get(.L));
try std.testing.expectEqual(0, M.dims.get(.T)); try std.testing.expectEqual(0, M.dims.get(.T));
// Test specific scale variants // Test specific scale variants
const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour }); const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour });
const speed = Kmh.splat(120); const speed = Kmh.splat(120);
try std.testing.expectEqual(120.0, speed.value()); try std.testing.expectEqual(120.0, speed.data[0]);
try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L)); try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L));
try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T)); try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T));
} }
@ -176,12 +185,12 @@ test "BaseQuantities - Kinematics equations" {
// Velocity = Distance / Time // Velocity = Distance / Time
const v = d.div(t); const v = d.div(t);
try std.testing.expectEqual(25.0, v.value()); try std.testing.expectEqual(25.0, v.data[0]);
try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
// Acceleration = Velocity / Time // Acceleration = Velocity / Time
const a = v.div(t); const a = v.div(t);
try std.testing.expectEqual(12.5, a.value()); try std.testing.expectEqual(12.5, a.data[0]);
try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
} }
@ -193,13 +202,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
// Force = mass * acceleration // Force = mass * acceleration
const f = m.mul(a); const f = m.mul(a);
try std.testing.expectEqual(98, f.value()); try std.testing.expectEqual(98, f.data[0]);
try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); try std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
// Energy (Work) = Force * distance // Energy (Work) = Force * distance
const distance = Meter.Of(f32).splat(5.0); const distance = Meter.Of(f32).splat(5.0);
const energy = f.mul(distance); const energy = f.mul(distance);
try std.testing.expectEqual(490, energy.value()); try std.testing.expectEqual(490, energy.data[0]);
try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
} }
@ -209,26 +218,26 @@ test "BaseQuantities - Electric combinations" {
// Charge = Current * time // Charge = Current * time
const charge = current.mul(time); const charge = current.mul(time);
try std.testing.expectEqual(6.0, charge.value()); try std.testing.expectEqual(6.0, charge.data[0]);
try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
} }
test "Constants - Initialization and dimension checks" { test "Constants - Initialization and dimension checks" {
// Speed of Light // Speed of Light
const c = Constants.SpeedOfLight.Of(f64); const c = Constants.SpeedOfLight.Of(f64);
try std.testing.expectEqual(299792458.0, c.value()); try std.testing.expectEqual(299792458.0, c.data[0]);
try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L)); try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L));
try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T)); try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T));
// Electron Mass (verifying scale as well) // Electron Mass (verifying scale as well)
const me = Constants.ElectronMass.Of(f64); const me = Constants.ElectronMass.Of(f64);
try std.testing.expectEqual(9.1093837139e-31, me.value()); try std.testing.expectEqual(9.1093837139e-31, me.data[0]);
try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M)); try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M));
try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg
// Boltzmann Constant (Complex derived dimensions) // Boltzmann Constant (Complex derived dimensions)
const kb = Constants.Boltzmann.Of(f64); const kb = Constants.Boltzmann.Of(f64);
try std.testing.expectEqual(1.380649e-23, kb.value()); try std.testing.expectEqual(1.380649e-23, kb.data[0]);
try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M)); try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M));
try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L)); try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L));
try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T)); try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T));
@ -237,7 +246,7 @@ test "Constants - Initialization and dimension checks" {
// Vacuum Permittivity // Vacuum Permittivity
const eps0 = Constants.VacuumPermittivity.Of(f64); const eps0 = Constants.VacuumPermittivity.Of(f64);
try std.testing.expectEqual(8.8541878188e-12, eps0.value()); try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]);
try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M)); try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M));
try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L)); try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L));
try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T)); try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T));
@ -245,7 +254,7 @@ test "Constants - Initialization and dimension checks" {
// Fine Structure Constant (Dimensionless) // Fine Structure Constant (Dimensionless)
const alpha = Constants.FineStructure.Of(f64); const alpha = Constants.FineStructure.Of(f64);
try std.testing.expectEqual(0.0072973525643, alpha.value()); try std.testing.expectEqual(0.0072973525643, alpha.data[0]);
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M)); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M));
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L)); try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L));
} }

View File

@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int),
/// Unspecified dimensions default to 0. /// Unspecified dimensions default to 0.
pub fn init(comptime init_val: ArgOpts) Self { pub fn init(comptime init_val: ArgOpts) Self {
var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
inline for (std.meta.fields(@TypeOf(init_val))) |f| for (std.meta.fields(@TypeOf(init_val))) |f|
s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
return s; return s;
} }
pub fn initFill(comptime val: comptime_int) Self { pub fn initFill(comptime val: comptime_int) Self {
return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
} }
pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
return self.data.get(key); comptime return self.data.get(key);
} }
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void
pub fn argsOpt(self: Self) ArgOpts { pub fn argsOpt(self: Self) ArgOpts {
var args: ArgOpts = undefined; var args: ArgOpts = undefined;
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
@field(args, @tagName(d)) = self.get(d); @field(args, @tagName(d)) = self.get(d);
return args; comptime return args;
} }
/// Add exponents component-wise. Used internally by `mul`. /// Add exponents component-wise. Used internally by `mul`.
pub fn add(comptime a: Self, comptime b: Self) Self { pub fn add(comptime a: Self, comptime b: Self) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) + b.get(d)); result.set(d, a.get(d) + b.get(d));
return result; comptime return result;
} }
/// Subtract exponents component-wise. Used internally by `div`. /// Subtract exponents component-wise. Used internally by `div`.
pub fn sub(comptime a: Self, comptime b: Self) Self { pub fn sub(comptime a: Self, comptime b: Self) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) - b.get(d)); result.set(d, a.get(d) - b.get(d));
return result; comptime return result;
} }
/// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) * exp); result.set(d, a.get(d) * exp);
return result; comptime return result;
} }
pub fn div(comptime a: Self, comptime exp: comptime_int) Self { pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| inline for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) / exp); result.set(d, a.get(d) / exp);
return result; comptime return result;
} }
/// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.

View File

@ -1,5 +1,4 @@
const std = @import("std"); const std = @import("std");
const hlp = @import("helper.zig");
const Dimensions = @import("Dimensions.zig"); const Dimensions = @import("Dimensions.zig");
const Dimension = @import("Dimensions.zig").Dimension; const Dimension = @import("Dimensions.zig").Dimension;
@ -66,7 +65,7 @@ pub const UnitScale = enum(isize) {
} }
pub inline fn getFactor(self: @This()) comptime_float { pub inline fn getFactor(self: @This()) comptime_float {
return comptime switch (self) { comptime return switch (self) {
// Standard SI Exponents // Standard SI Exponents
inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
@ -84,7 +83,7 @@ pub const UnitScale = enum(isize) {
inline .lb => 453.59237, inline .lb => 453.59237,
inline .st => 6350.29318, inline .st => 6350.29318,
inline else => @floatFromInt(@intFromEnum(self)), else => @floatFromInt(@intFromEnum(self)),
}; };
} }
}; };
@ -99,16 +98,16 @@ data: std.EnumArray(Dimension, UnitScale),
pub fn init(comptime init_val: ArgOpts) Self { pub fn init(comptime init_val: ArgOpts) Self {
comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) }; comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) };
inline for (std.meta.fields(@TypeOf(init_val))) |f| { inline for (std.meta.fields(@TypeOf(init_val))) |f| {
if (comptime hlp.isInt(@TypeOf(@field(init_val, f.name)))) if (comptime @typeInfo(@TypeOf(@field(init_val, f.name))) == .comptime_int)
s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name))) s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name)))
else else
s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
} }
return s; return comptime s;
} }
pub fn initFill(comptime val: UnitScale) Self { pub fn initFill(comptime val: UnitScale) Self {
return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
} }
pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
@ -116,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
} }
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
comptime self.data.set(key, val); self.data.set(key, val);
} }
pub fn argsOpt(self: Self) ArgOpts { pub fn argsOpt(self: Self) ArgOpts {
@ -145,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float
factor /= base; factor /= base;
} }
} }
return comptime factor; comptime return factor;
} }

View File

@ -1,47 +1,57 @@
const std = @import("std"); const std = @import("std");
const hlp = @import("helper.zig");
const Scales = @import("Scales.zig"); const Scales = @import("Scales.zig");
const UnitScale = Scales.UnitScale; const UnitScale = Scales.UnitScale;
const Dimensions = @import("Dimensions.zig"); const Dimensions = @import("Dimensions.zig");
const Dimension = Dimensions.Dimension; const Dimension = Dimensions.Dimension;
// //
// Comptime shape utilities // Comptime utilities
// //
pub fn shapeTotal(comptime shape: []const usize) usize { pub fn shapeTotal(comptime shape: []const comptime_int) usize {
var t: usize = 1; var t: comptime_int = 1;
for (shape) |s| t *= s; for (shape) |s| t *= s;
return t; return t;
} }
/// Check if two shapes are strictly identical.
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
if (a.len != b.len) return false;
for (a, 0..) |v, i|
if (v != b[i]) return false;
return true;
}
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
/// e.g. shape {3, 4} -> strides {4, 1} /// e.g. shape {3, 4} -> strides {4, 1}
/// shape {2, 3, 4} -> strides {12, 4, 1} /// shape {2, 3, 4} -> strides {12, 4, 1}
pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
var st: [shape.len]usize = undefined; var st: [shape.len]comptime_int = undefined;
if (shape.len == 0) return st; if (shape.len == 0) return st;
st[shape.len - 1] = 1; st[shape.len - 1] = 1;
if (shape.len > 1) { if (shape.len > 1) {
var i: usize = shape.len - 1; var i: comptime_int = shape.len - 1;
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
} }
return st; return st;
} }
/// Return a copy of `shape` with the element at `axis` removed. /// Return a copy of `shape` with the element at `axis` removed.
pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
var out: [shape.len - 1]usize = undefined; var out: [shape.len - 1]comptime_int = undefined;
var j: usize = 0; var j: comptime_int = 0;
for (shape, 0..) |v, i| { for (shape, 0..) |v, i| {
if (i != axis) { out[j] = v; j += 1; } if (i != axis) {
out[j] = v;
j += 1;
}
} }
return out; return out;
} }
/// Concatenate two compile-time slices. /// Concatenate two compile-time slices.
pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
var out: [a.len + b.len]usize = undefined; var out: [a.len + b.len]comptime_int = undefined;
for (a, 0..) |v, i| out[i] = v; for (a, 0..) |v, i| out[i] = v;
for (b, 0..) |v, i| out[a.len + i] = v; for (b, 0..) |v, i| out[a.len + i] = v;
return out; return out;
@ -50,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b
/// Decode a flat row-major index into N-D coordinates. /// Decode a flat row-major index into N-D coordinates.
/// Called only in comptime contexts (all arguments are comptime). /// Called only in comptime contexts (all arguments are comptime).
pub fn decodeFlatCoords( pub fn decodeFlatCoords(
comptime flat: usize, comptime flat: comptime_int,
comptime n: usize, comptime n: comptime_int,
comptime strd: [n]usize, comptime strd: [n]comptime_int,
) [n]usize { ) [n]usize {
var coords: [n]usize = undefined; var coords: [n]comptime_int = undefined;
var tmp = flat; var tmp = flat;
for (0..n) |i| { for (0..n) |i| {
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
@ -96,22 +106,48 @@ pub fn insertAxis(
return out; return out;
} }
fn isInt(comptime T: type) bool {
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
}
fn finerScales(comptime T1: type, comptime T2: type) Scales {
const d1: Dimensions = T1.dims;
const d2: Dimensions = T2.dims;
const s1: Scales = T1.scales;
const s2: Scales = T2.scales;
comptime var out = Scales.initFill(.none);
for (std.enums.values(Dimension)) |dim| {
const scale1 = comptime s1.get(dim);
const scale2 = comptime s2.get(dim);
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
.none
else if (comptime d1.get(dim) == 0)
scale2
else if (comptime d2.get(dim) == 0)
scale1
else if (comptime scale1.getFactor() > scale2.getFactor())
scale2
else
scale1);
}
comptime return out;
}
// //
// File-scope RHS normalisation helpers // File-scope RHS normalisation helpers
//
// Any bare comptime_int / comptime_float / runtime T used as an arithmetic
// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}.
// Actual Tensor types are passed through unchanged.
// //
fn isTensor(comptime Rhs: type) bool {
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
}
fn RhsTensorType(comptime T: type, comptime Rhs: type) type { fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
if (@hasDecl(Rhs, "ISTENSOR")) return Rhs; if (comptime isTensor(Rhs)) return Rhs;
return Tensor(T, .{}, .{}, &.{1}); return Tensor(T, .{}, .{}, &.{1});
} }
fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
const Rhs = @TypeOf(r); const Rhs = @TypeOf(r);
if (comptime @hasDecl(Rhs, "ISTENSOR")) return r; if (comptime isTensor(Rhs)) return r;
const scalar: T = switch (comptime @typeInfo(Rhs)) { const scalar: T = switch (comptime @typeInfo(Rhs)) {
.comptime_int => switch (comptime @typeInfo(T)) { .comptime_int => switch (comptime @typeInfo(T)) {
.float => @as(T, @floatFromInt(r)), .float => @as(T, @floatFromInt(r)),
@ -134,55 +170,44 @@ fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
} }
// pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
// Tensor unified dimensioned ND type. if (n == 0) return;
// var val = n;
// T : element numeric type (f32, f64, i32, i128, ) if (val < 0) {
// d_opt : SI dimension exponents try writer.writeAll("\u{207B}");
// s_opt : unit scales val = -val;
// shape_ : compile-time shape }
// &.{1} scalar var buf: [12]u8 = undefined;
// &.{3} 3-vector const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
// &.{4, 4} 4×4 matrix for (str) |c| {
// &.{3, 3, 3} 3D field const s = switch (c) {
// '0' => "\u{2070}",
// Storage: flat @Vector(total, T) where total = product(shape_). '1' => "\u{00B9}",
// All arithmetic operates on the flat vector directly SIMD wherever possible. '2' => "\u{00B2}",
// '3' => "\u{00B3}",
// Shape-related comptime constants exposed on every Tensor type: '4' => "\u{2074}",
// dims : Dimensions SI exponent struct '5' => "\u{2075}",
// scales : Scales unit scale struct '6' => "\u{2076}",
// shape : []const usize '7' => "\u{2077}",
// rank : usize = shape.len '8' => "\u{2078}",
// total : usize = product(shape) '9' => "\u{2079}",
// strides_arr : [rank]usize row-major strides else => unreachable,
// };
// Index helper: try writer.writeAll(s);
// Tensor.idx(.{row, col}) flat index (comptime, no runtime cost) }
// }
// GPU readiness:
// tensor.asSlice() []T (zero-copy pointer to the flat @Vector storage)
//
// Contraction (replaces dot / cross / matmul):
// a.contract(b, axis_a, axis_b)
// For rank-1 × rank-1 this is the dot product.
// For rank-2 × rank-2 with axis_a=1, axis_b=0 this is matrix multiply.
//
// Removed from Quantity:
// Scalar / Vector aliases, Vec3 / ScalarType, .value(), .vec(), .vec3(),
// dot(), cross(), mulScalar(), divScalar(), eqScalar() and friends.
// Use Tensor(..., &.{1}), .data[0], mul(), div(), eq() respectively.
//
pub fn Tensor( pub fn Tensor(
comptime T: type, comptime T: type,
comptime d_opt: Dimensions.ArgOpts, comptime d_opt: Dimensions.ArgOpts,
comptime s_opt: Scales.ArgOpts, comptime s_opt: Scales.ArgOpts,
comptime shape_: []const usize, comptime shape_: []const comptime_int,
) type { ) type {
comptime { comptime {
std.debug.assert(shape_.len >= 1); if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
for (shape_) |s| std.debug.assert(s >= 1); for (shape_) |s| {
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
}
} }
@setEvalBranchQuota(10_000_000); @setEvalBranchQuota(10_000_000);
@ -191,41 +216,32 @@ pub fn Tensor(
const Vec = @Vector(_total, T); const Vec = @Vector(_total, T);
return struct { return struct {
/// Flat SIMD storage. All arithmetic operates here directly.
data: Vec, data: Vec,
const Self = @This(); const Self = @This();
pub const ValueType : type = T; pub const ValueType: type = T;
pub const dims : Dimensions = Dimensions.init(d_opt); pub const dims: Dimensions = Dimensions.init(d_opt);
pub const scales : Scales = Scales.init(s_opt); pub const scales: Scales = Scales.init(s_opt);
pub const shape : []const usize = shape_; pub const shape: []const comptime_int = shape_;
pub const rank : usize = shape_.len; pub const rank: comptime_int = shape_.len;
pub const total : usize = _total; pub const total: comptime_int = _total;
pub const strides_arr: [shape_.len]usize = _strides; pub const strides_arr: [shape_.len]comptime_int = _strides;
pub const ISTENSOR = true; pub const ISTENSOR = true;
//
// Index helper
//
/// Convert N-D coords (row-major) to flat index - fully comptime. /// Convert N-D coords (row-major) to flat index - fully comptime.
/// Usage: Tensor.idx(.{row, col}) /// Usage: Tensor.idx(.{row, col})
pub fn idx(comptime coords: [rank]usize) usize { pub inline fn idx(comptime coords: [rank]usize) usize {
comptime { comptime {
var flat: usize = 0; var flat: usize = 0;
for (0..rank) |i| { for (0..rank) |i| {
std.debug.assert(coords[i] < shape[i]); if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds");
flat += coords[i] * strides_arr[i]; flat += coords[i] * strides_arr[i];
} }
return flat; return flat;
} }
} }
//
// Constructors
//
/// Broadcast a single value across all elements. /// Broadcast a single value across all elements.
pub inline fn splat(v: T) Self { pub inline fn splat(v: T) Self {
return .{ .data = @splat(v) }; return .{ .data = @splat(v) };
@ -234,25 +250,17 @@ pub fn Tensor(
pub const zero: Self = splat(0); pub const zero: Self = splat(0);
pub const one: Self = splat(1); pub const one: Self = splat(1);
//
// GPU readiness
//
/// Return a mutable slice to the flat storage - zero-copy WebGPU buffer mapping. /// Return a mutable slice to the flat storage - zero-copy WebGPU buffer mapping.
pub inline fn asSlice(self: *Self) []T { pub inline fn asSlice(self: *Self) []T {
return @as([*]T, @ptrCast(&self.data))[0..total]; return @as([*]T, @ptrCast(&self.data))[0..total];
} }
// inline fn RhsT(comptime Rhs: type) type {
// Internal: RHS normalisation return RhsTensorType(T, Rhs);
// }
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) {
inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } return toRhsTensor(T, r);
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { return toRhsTensor(T, r); } }
//
// Internal: scalar broadcast (shape {1} full Vec)
//
inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec { inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
return if (comptime RhsType.total == 1 and total > 1) return if (comptime RhsType.total == 1 and total > 1)
@ -261,57 +269,54 @@ pub fn Tensor(
r.data; r.data;
} }
//
// Arithmetic
//
/// Element-wise add. Dimensions must match; scales resolve to finer. /// Element-wise add. Dimensions must match; scales resolve to finer.
/// RHS must have the same element count as self, or total == 1 (broadcast). /// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn add(self: Self, r: anytype) Tensor( pub inline fn add(self: Self, r: anytype) Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q);
if (comptime !dims.eql(RhsType.dims)) if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != total and RhsType.total != 1) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1)."); @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
const rr: Vec = blk: { const rr: Vec = blk: {
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn);
}; };
return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr }; return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
} }
/// Element-wise subtract. Dimensions must match; scales resolve to finer. /// Element-wise sub. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn sub(self: Self, r: anytype) Tensor( pub inline fn sub(self: Self, r: anytype) Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q);
if (comptime !dims.eql(RhsType.dims)) if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != total and RhsType.total != 1) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in sub."); @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
const rr: Vec = blk: { const rr: Vec = blk: {
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn);
}; };
return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr }; return .{ .data = if (comptime isInt(T)) l -| rr else l - rr };
} }
/// Element-wise multiply. Dimension exponents summed. /// Element-wise multiply. Dimension exponents summed.
@ -319,20 +324,20 @@ pub fn Tensor(
pub inline fn mul(self: Self, r: anytype) Tensor( pub inline fn mul(self: Self, r: anytype) Tensor(
T, T,
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q);
if (comptime RhsType.total != total and RhsType.total != 1) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in mul."); @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
const rr: Vec = broadcastToVec(RhsNorm, rr_base); const rr: Vec = broadcastToVec(RhsNorm, rr_base);
return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr }; return .{ .data = if (comptime isInt(T)) l *| rr else l * rr };
} }
/// Element-wise divide. Dimension exponents subtracted. /// Element-wise divide. Dimension exponents subtracted.
@ -340,32 +345,26 @@ pub fn Tensor(
pub inline fn div(self: Self, r: anytype) Tensor( pub inline fn div(self: Self, r: anytype) Tensor(
T, T,
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q);
if (comptime RhsType.total != total and RhsType.total != 1) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in div."); @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
const rr: Vec = broadcastToVec(RhsNorm, rr_base); const rr: Vec = broadcastToVec(RhsNorm, rr_base);
if (comptime hlp.isInt(T)) { if (comptime isInt(T)) {
var result: Vec = undefined; return .{ .data = @divTrunc(l, rr) };
inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
return .{ .data = result };
} else { } else {
return .{ .data = l / rr }; return .{ .data = l / rr };
} }
} }
//
// Unary
//
/// Absolute value of every element. /// Absolute value of every element.
pub inline fn abs(self: Self) Self { pub inline fn abs(self: Self) Self {
return .{ .data = @bitCast(@abs(self.data)) }; return .{ .data = @bitCast(@abs(self.data)) };
@ -378,19 +377,27 @@ pub fn Tensor(
scales.argsOpt(), scales.argsOpt(),
shape_, shape_,
) { ) {
if (comptime hlp.isInt(T)) { if (comptime exp == 0) return .{ .data = @splat(1) };
var result: Vec = undefined; if (comptime exp == 1) return self;
inline for (0..total) |i|
result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); var base = self.data;
return .{ .data = result };
} else {
const abs_exp = comptime @abs(exp);
var result: Vec = @splat(1); var result: Vec = @splat(1);
comptime var i = 0; comptime var e = @abs(exp);
inline while (i < abs_exp) : (i += 1) result *= self.data;
if (comptime exp < 0) result = @as(Vec, @splat(1)) / result; // $O(\log n)$ Exponentiation by squaring applied to the entire vector
return .{ .data = result }; inline while (e > 0) {
if (e % 2 == 1) {
result = if (comptime isInt(T)) result *| base else result * base;
} }
e /= 2;
if (e > 0) {
base = if (comptime isInt(T)) base *| base else base * base;
}
}
if (comptime !isInt(T) and exp < 0) {
result = @as(Vec, @splat(1)) / result;
}
return .{ .data = result };
} }
/// Square root of every element. All dimension exponents must be even. /// Square root of every element. All dimension exponents must be even.
@ -403,15 +410,16 @@ pub fn Tensor(
if (comptime !dims.isSquare()) if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
if (comptime @typeInfo(T) == .float) { if (comptime @typeInfo(T) == .float) {
return .{ .data = @sqrt(self.data) }; return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
} else { } else {
var result: Vec = undefined; const arr: [total]T = self.data; // Add this!
var res_arr: [total]T = undefined;
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
inline for (0..total) |i| { for (0..total) |i| {
const v = self.data[i]; const v = arr[i];
result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
} }
return .{ .data = result }; return .{ .data = res_arr };
} }
} }
@ -420,13 +428,9 @@ pub fn Tensor(
return .{ .data = -self.data }; return .{ .data = -self.data };
} }
//
// Conversion
//
/// Convert to a compatible Tensor type. /// Convert to a compatible Tensor type.
/// Dimension mismatch compile error. /// Dimension mismatch compile error.
/// Dest.total must equal self.total, or Dest.total == 1 (scalar pattern). /// Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern).
/// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. /// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
pub inline fn to( pub inline fn to(
self: Self, self: Self,
@ -434,82 +438,92 @@ pub fn Tensor(
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
if (comptime !dims.eql(ActualDest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
if (comptime Self == ActualDest) return self; if (comptime Self == ActualDest) return self;
comptime std.debug.assert(Dest.total == total or Dest.total == 1); // Run validation checks FIRST before dealing with types
if (comptime !dims.eql(ActualDest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
const DestT = ActualDest.ValueType;
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT); const DestVec = @Vector(total, DestT);
// Same numeric type // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
if (comptime ratio == 1.0) {
const T_info = @typeInfo(T);
const Dest_info = @typeInfo(DestT);
return .{
.data = if (comptime T_info == .int and Dest_info == .int)
@as(DestVec, @intCast(self.data))
else if (comptime T_info == .float and Dest_info == .float)
@as(DestVec, @floatCast(self.data))
else if (comptime T_info == .int and Dest_info == .float)
@as(DestVec, @floatFromInt(self.data))
else if (comptime T_info == .float and Dest_info == .int)
@as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
else
unreachable,
};
}
if (comptime T == DestT) { if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) if (comptime @typeInfo(T) == .float)
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
// Integer branch prevents division-by-zero
if (comptime ratio >= 1.0) { if (comptime ratio >= 1.0) {
const mult: T = comptime @intFromFloat(@round(ratio)); const mult: T = comptime @intFromFloat(@round(ratio));
return .{ .data = self.data *| @as(Vec, @splat(mult)) }; return .{ .data = self.data *| @as(Vec, @splat(mult)) };
} else { } else {
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
const half: T = comptime @divTrunc(div_val, 2); const half: T = comptime @divTrunc(div_val, 2);
var result: DestVec = undefined;
inline for (0..total) |i| { if (comptime @typeInfo(T).int.signedness == .unsigned) {
const val = self.data[i]; return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
result[i] = if (val >= 0) } else {
@divTrunc(val + half, div_val) // Vectorized branchless negative handling
else const is_pos = self.data >= @as(Vec, @splat(0));
@divTrunc(val - half, div_val); const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
} }
return .{ .data = result };
} }
} }
// Cross numeric type // Cross-type fully vectorized casting with scales
var result: DestVec = undefined; const FVec = @Vector(total, f64);
inline for (0..total) |i| { const float_vec: FVec = switch (comptime @typeInfo(T)) {
const float_val: f64 = switch (comptime @typeInfo(T)) { .float => @floatCast(self.data),
.float => @floatCast(self.data[i]), .int => @floatFromInt(self.data),
.int => @floatFromInt(self.data[i]),
else => unreachable, else => unreachable,
}; };
const scaled = float_val * ratio;
result[i] = switch (comptime @typeInfo(DestT)) {
.float => @floatCast(scaled),
.int => @intFromFloat(@round(scaled)),
else => unreachable,
};
}
return .{ .data = result };
}
// const scaled = float_vec * @as(FVec, @splat(ratio));
// Comparisons
// return switch (comptime @typeInfo(DestT)) {
// Return type: bool when total == 1 (scalar semantics) .float => .{ .data = @floatCast(scaled) },
// [total]bool when total > 1 (element-wise, flat-indexed) .int => .{ .data = @intFromFloat(@round(scaled)) },
// else => unreachable,
// Whole-tensor equality check eqAll / neAll (always returns bool). };
// A shape {1} RHS is broadcast automatically, unifying the old }
// eqScalar / gtScalar / family into the plain eq / gt / methods.
//
const CmpResult = if (total == 1) bool else [total]bool; const CmpResult = if (total == 1) bool else [total]bool;
inline fn cmpResult(v: @Vector(total, bool)) CmpResult { inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
return if (comptime total == 1) v[0] else @as([total]bool, v); return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
} }
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q);
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
const rr: Vec = blk: { const rr: Vec = blk: {
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn);
}; };
@ -577,26 +591,6 @@ pub fn Tensor(
return !self.eqAll(other); return !self.eqAll(other);
} }
//
// Contraction generalised dot product / matrix multiply / einsum
//
// a.contract(b, axis_a, axis_b)
//
// Sums over dimension `axis_a` of `a` and `axis_b` of `b`.
// Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime).
//
// Result shape = a.shape \ axis_a ++ b.shape \ axis_b
// Result dims = a.dims + b.dims (exponents summed, as in mul)
// Result scales = finer of a, b
//
// Special cases:
// rank-1 × rank-1, axis 0 × 0 dot product (result shape {1})
// rank-2 × rank-2, axis 1 × 0 matrix multiply
// rank-1 × rank-2, axis 0 × 0 vectormatrix product
//
// All index arithmetic is comptime; runtime cost is the multiply-add loop only.
//
pub inline fn contract( pub inline fn contract(
self: Self, self: Self,
other: anytype, other: anytype,
@ -604,18 +598,18 @@ pub fn Tensor(
comptime axis_b: usize, comptime axis_b: usize,
) blk: { ) blk: {
const OT = @TypeOf(other); const OT = @TypeOf(other);
comptime std.debug.assert(axis_a < rank); if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
comptime std.debug.assert(axis_b < OT.rank); if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
comptime std.debug.assert(shape_[axis_a] == OT.shape[axis_b]); if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
// Contracted-away free axes; empty joint scalar shape {1}
const sa = shapeRemoveAxis(shape_, axis_a); const sa = shapeRemoveAxis(shape_, axis_a);
const sb = shapeRemoveAxis(OT.shape, axis_b); const sb = shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = shapeCat(&sa, &sb); const rs_raw = shapeCat(&sa, &sb);
const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
break :blk Tensor( break :blk Tensor(
T, T,
dims.add(OT.dims).argsOpt(), dims.add(OT.dims).argsOpt(),
hlp.finerScales(Self, OT).argsOpt(), finerScales(Self, OT).argsOpt(),
rs, rs,
); );
} { } {
@ -625,55 +619,131 @@ pub fn Tensor(
const sa = comptime shapeRemoveAxis(shape_, axis_a); const sa = comptime shapeRemoveAxis(shape_, axis_a);
const sb = comptime shapeRemoveAxis(OT.shape, axis_b); const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = comptime shapeCat(&sa, &sb); const rs_raw = comptime shapeCat(&sa, &sb);
const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
const ResultType = Tensor( const ResultType = Tensor(
T, T,
dims.add(OT.dims).argsOpt(), dims.add(OT.dims).argsOpt(),
hlp.finerScales(Self, OT).argsOpt(), finerScales(Self, OT).argsOpt(),
rs, rs,
); );
// Normalise scales before accumulation const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), shape_); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
const OtherNorm = Tensor(T, OT.dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), OT.shape);
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
// Precompute result strides from rs_raw (for coord decoding) // FAST PATH: Dot Product
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
if (comptime !isInt(T)) {
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
} else {
// For integers, we do a vectorized saturating multiply,
// then convert to an array to do a saturating sum
const mul_arr: [total]T = a_data *| b_data;
var acc: T = 0;
for (mul_arr) |val| acc +|= val;
return .{ .data = @splat(acc) };
}
}
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
const a_arr: [total]T = a_data;
const b_arr: [OT.total]T = b_data;
// FAST PATH: 2D Matrix Multiplication
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
const rows = shape_[0];
const cols = OT.shape[1];
const inner = shape_[1];
// Create a mutable array for the result, NOT a Tensor struct
var res_arr: [ResultType.total]T = undefined;
for (0..rows) |i| {
for (0..cols) |j| {
var acc: T = 0;
for (0..inner) |id| {
const a_flat = i * _strides[0] + id * _strides[1];
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
// Use a_arr and b_arr here
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
}
// Write to the array
res_arr[i * cols + j] = acc;
}
}
// Return the initialized Tensor struct
return .{ .data = res_arr };
}
// FALLBACK PATH
const rs_raw_strides = comptime shapeStrides(&rs_raw); const rs_raw_strides = comptime shapeStrides(&rs_raw);
var result: ResultType = .{ .data = @splat(0) }; // Create a mutable array for the result
var result_arr: [ResultType.total]T = undefined;
inline for (0..ResultType.total) |res_flat| { for (0..ResultType.total) |res_flat| {
// Decode result flat index into free coords using rs_raw layout. const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
// When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} correct.
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; var a_free: [sa.len]usize = undefined;
const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; for (0..sa.len) |i| a_free[i] = res_coords[i];
var b_free: [sb.len]usize = undefined;
for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];
var acc: T = 0; var acc: T = 0;
inline for (0..k) |ki| { for (0..k) |ki| {
// Reinsert the contracted index into free coords full coord arrays const a_coords = insertAxis(rank, axis_a, ki, &a_free);
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
if (comptime hlp.isInt(T)) // Use a_arr and b_arr here
acc +|= a_data[a_flat] *| b_data[b_flat] if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
else
acc += a_data[a_flat] * b_data[b_flat];
} }
result.data[res_flat] = acc; // Write to the array
} result_arr[res_flat] = acc;
return result;
} }
// // Return the initialized Tensor struct
// Reduction helpers return .{ .data = result_arr };
// }
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
/// Result dimensions are the sum of input dimensions.
pub inline fn cross(self: Self, other: anytype) Tensor(
T,
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
&.{3},
) {
const rhs_q = rhs(other);
const RhsType = @TypeOf(rhs_q);
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
}
// Bring both to the same scale (e.g., mm vs m)
const p = self.resolveScalePair(rhs_q);
const l = p.l;
const r = p.r;
var res: [3]T = undefined;
if (comptime isInt(T)) {
res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
} else {
res[0] = (l[1] * r[2]) - (l[2] * r[1]);
res[1] = (l[2] * r[0]) - (l[0] * r[2]);
res[2] = (l[0] * r[1]) - (l[1] * r[0]);
}
return .{ .data = res };
}
/// Sum of squared elements. Cheaper than length(); use for ordering. /// Sum of squared elements. Cheaper than length(); use for ordering.
pub inline fn lengthSqr(self: Self) T { pub inline fn lengthSqr(self: Self) T {
@ -700,10 +770,6 @@ pub fn Tensor(
return .{ .data = .{@reduce(.Mul, self.data)} }; return .{ .data = .{@reduce(.Mul, self.data)} };
} }
//
// Formatting
//
pub fn formatNumber( pub fn formatNumber(
self: Self, self: Self,
writer: *std.Io.Writer, writer: *std.Io.Writer,
@ -722,7 +788,8 @@ pub fn Tensor(
} }
} else { } else {
try writer.writeAll("("); try writer.writeAll("(");
inline for (0..total) |i| { const max_to_print = 6;
inline for (0..@min(total, max_to_print)) |i| {
if (i > 0) try writer.writeAll(", "); if (i > 0) try writer.writeAll(", ");
switch (@typeInfo(T)) { switch (@typeInfo(T)) {
.float, .comptime_float => try writer.printFloat(self.data[i], options), .float, .comptime_float => try writer.printFloat(self.data[i], options),
@ -734,6 +801,8 @@ pub fn Tensor(
}), }),
else => unreachable, else => unreachable,
} }
if (comptime i == max_to_print - 1 and total != max_to_print - 1)
try writer.writeAll(", ...");
} }
try writer.writeAll(")"); try writer.writeAll(")");
} }
@ -751,7 +820,7 @@ pub fn Tensor(
else else
try writer.print("{s}{s}", .{ uscale.str(), bu.unit() }); try writer.print("{s}{s}", .{ uscale.str(), bu.unit() });
if (v != 1) try hlp.printSuperscript(writer, v); if (v != 1) try printSuperscript(writer, v);
} }
} }
}; };
@ -760,15 +829,6 @@ pub fn Tensor(
// //
// Tests // Tests
// //
// Naming convention used throughout:
// Tensor(T, d, s, &.{1}) former Scalar
// Tensor(T, d, s, &.{N}) former Vector of length N
// .data[0] former .value()
// .mul(x) former .mulScalar(x) (x may be scalar Tensor or bare number)
// .div(x) former .divScalar(x)
// .eq(x) former .eqScalar(x) (broadcasts when x.total==1)
// .contract(other, 0, 0) former .dot(other) (for rank-1 tensors)
//
// Scalar tests // Scalar tests
@ -1185,7 +1245,6 @@ test "Vector Comparisons" {
} }
test "Vector vs Scalar broadcast comparison" { test "Vector vs Scalar broadcast comparison" {
// Replaces the old eqScalar / gtScalar now just eq / gt with a shape-{1} rhs.
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
@ -1210,7 +1269,6 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
// work = force · pos
const work = force.contract(pos, 0, 0); const work = force.contract(pos, 0, 0);
try std.testing.expectEqual(50.0, work.data[0]); try std.testing.expectEqual(50.0, work.data[0]);
try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M)); try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M));
@ -1219,28 +1277,17 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
} }
test "Vector contract — matrix multiply (rank-2 × rank-2)" { test "Vector contract — matrix multiply (rank-2 × rank-2)" {
// 2×3 matrix multiplied by 3×2 matrix 2×2 result const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
const A = Tensor(f32, .{}, .{}, &.{2, 3}); const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
const B = Tensor(f32, .{}, .{}, &.{3, 2});
// A = [[1, 2, 3],
// [4, 5, 6]]
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
// B = [[7, 8],
// [9, 10],
// [11, 12]]
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
// C = A @ B (contract over axis 1 of A × axis 0 of B)
// C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
// C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
// C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
// C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
const c = a.contract(b, 1, 0); const c = a.contract(b, 1, 0);
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 0})]); try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 1})]); try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 0})]); try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]);
try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 1})]); try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]);
} }
test "Vector Abs, Pow, Sqrt and Product" { test "Vector Abs, Pow, Sqrt and Product" {
@ -1321,23 +1368,22 @@ test "Vector eq broadcast on dimensionless" {
} }
test "Tensor idx helper and matrix access" { test "Tensor idx helper and matrix access" {
const Mat3x3 = Tensor(f32, .{}, .{}, &.{3, 3}); const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
// Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3
var m: Mat3x3 = Mat3x3.zero; var m: Mat3x3 = Mat3x3.zero;
m.data[Mat3x3.idx(.{0, 0})] = 1.0; m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
m.data[Mat3x3.idx(.{1, 1})] = 2.0; m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
m.data[Mat3x3.idx(.{2, 2})] = 3.0; m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
try std.testing.expectEqual(1.0, m.data[0]); // [0][0] try std.testing.expectEqual(1.0, m.data[0]);
try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 1*3+1=4) try std.testing.expectEqual(2.0, m.data[4]);
try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8) try std.testing.expectEqual(3.0, m.data[8]);
try std.testing.expectEqual(0.0, m.data[1]); // [0][1] try std.testing.expectEqual(0.0, m.data[1]);
} }
test "Tensor strides_arr correctness" { test "Tensor strides_arr correctness" {
const T1 = Tensor(f32, .{}, .{}, &.{3}); const T1 = Tensor(f32, .{}, .{}, &.{3});
const T2 = Tensor(f32, .{}, .{}, &.{3, 4}); const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 });
const T3 = Tensor(f32, .{}, .{}, &.{2, 3, 4}); const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 });
try std.testing.expectEqual(1, T1.strides_arr[0]); try std.testing.expectEqual(1, T1.strides_arr[0]);
try std.testing.expectEqual(4, T2.strides_arr[0]); try std.testing.expectEqual(4, T2.strides_arr[0]);

View File

@ -1,7 +1,6 @@
const std = @import("std"); const std = @import("std");
const Io = std.Io; const Io = std.Io;
const Scalar = @import("Quantity.zig").Scalar; const Tensor = @import("Tensor.zig").Tensor;
const Vector = @import("Quantity.zig").Vector;
var io: Io = undefined; var io: Io = undefined;
pub fn main(init: std.process.Init) !void { pub fn main(init: std.process.Init) !void {
@ -11,16 +10,16 @@ pub fn main(init: std.process.Init) !void {
io = init.io; io = init.io;
// try vectorSIMDvsNative(f64, &stdout_writer.interface); try vectorSIMDvsNative(f64, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(f32, &stdout_writer.interface); try vectorSIMDvsNative(f32, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i32, &stdout_writer.interface); try vectorSIMDvsNative(i32, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i64, &stdout_writer.interface); try vectorSIMDvsNative(i64, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i128, &stdout_writer.interface); try vectorSIMDvsNative(i128, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
try bench_Scalar(&stdout_writer.interface); try bench_Scalar(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
comptime var tidx: usize = 0; comptime var tidx: usize = 0;
inline for (Types, TNames) |T, tname| { inline for (Types, TNames) |T, tname| {
const M = Scalar(T, .{ .L = 1 }, .{}); const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k }); const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1});
const S = Scalar(T, .{ .T = 1 }, .{}); const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
inline for (Ops, 0..) |op_name, oidx| { inline for (Ops, 0..) |op_name, oidx| {
var samples: [SAMPLES]f64 = undefined; var samples: [SAMPLES]f64 = undefined;
@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0; var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0; var quantity_total_ns: f64 = 0;
const M = Scalar(T, .{ .L = 1 }, .{}); const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
const S = Scalar(T, .{ .T = 1 }, .{}); const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
std.mem.doNotOptimizeAway({ std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| { for (0..SAMPLES) |_| {
@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0; var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0; var quantity_total_ns: f64 = 0;
const M1 = Scalar(T1, .{ .L = 1 }, .{}); const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1});
const M2 = Scalar(T2, .{ .L = 1 }, .{}); const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1});
const S2 = Scalar(T2, .{ .T = 1 }, .{}); const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1});
std.mem.doNotOptimizeAway({ std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| { for (0..SAMPLES) |_| {
@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname }); try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname });
inline for (Lengths) |len| { inline for (Lengths) |len| {
const Q_base = Scalar(T, .{ .L = 1 }, .{}); const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1});
const Q_time = Scalar(T, .{ .T = 1 }, .{}); const V = Tensor(T, .{ .L = 1 }, .{}, &.{len});
const V = Vector(len, Q_base);
// cross product is only defined for len == 3 // cross product is only defined for len == 3
const is_cross = comptime std.mem.eql(u8, op_name, "cross"); const is_cross = comptime std.mem.eql(u8, op_name, "cross");
@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
_ = v1.div(V.splat(getVal(T, i +% 2, 63))); _ = v1.div(V.splat(getVal(T, i +% 2, 63)));
} else if (comptime std.mem.eql(u8, op_name, "mulScalar")) { } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) {
const s_val = Q_time.splat(getVal(T, i +% 2, 63)); const s_val = Q_time.splat(getVal(T, i +% 2, 63));
_ = v1.mulScalar(s_val); _ = v1.mul(s_val);
} else if (comptime std.mem.eql(u8, op_name, "dot")) { } else if (comptime std.mem.eql(u8, op_name, "dot")) {
const v2 = V.splat(getVal(T, i +% 5, 63)); const v2 = V.splat(getVal(T, i +% 5, 63));
_ = v1.dot(v2); _ = v1.contract(v2, 0, 0);
} else if (comptime std.mem.eql(u8, op_name, "cross")) { } else if (comptime std.mem.eql(u8, op_name, "cross")) {
// len == 3 guaranteed by the guard above // len == 3 guaranteed by the guard above
const v2 = V.splat(getVal(T, i +% 5, 63)); const v2 = V.splat(getVal(T, i +% 5, 63));

View File

@ -1,97 +0,0 @@
const std = @import("std");
/// Reports whether `T` is an integer type.
///
/// Both runtime integer types (`.int`) and `comptime_int` count as integers;
/// every other kind of type yields `false`.
pub fn isInt(comptime T: type) bool {
    return switch (@typeInfo(T)) {
        .int, .comptime_int => true,
        else => false,
    };
}
/// Writes `n` to `writer` using Unicode superscript characters.
///
/// - `n == 0` writes nothing.
/// - A negative `n` is prefixed with the superscript minus sign (U+207B).
/// - Each decimal digit is emitted as its superscript code point.
///
/// Any error returned by `writer.writeAll` is propagated to the caller.
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
    if (n == 0) return;

    // Superscript glyphs indexed by digit value 0..9.
    const superscript_digits = [10][]const u8{
        "\u{2070}",
        "\u{00B9}",
        "\u{00B2}",
        "\u{00B3}",
        "\u{2074}",
        "\u{2075}",
        "\u{2076}",
        "\u{2077}",
        "\u{2078}",
        "\u{2079}",
    };

    if (n < 0) try writer.writeAll("\u{207B}");

    // `@abs` yields a u32, so the magnitude of minInt(i32) is representable;
    // negating the i32 directly would overflow for that input.
    const magnitude: u32 = @abs(n);

    // A u32 prints as at most 10 decimal digits, so this buffer always fits.
    var buf: [12]u8 = undefined;
    const digits = std.fmt.bufPrint(&buf, "{d}", .{magnitude}) catch unreachable;

    for (digits) |c| {
        try writer.writeAll(superscript_digits[c - '0']);
    }
}
const Scales = @import("Scales.zig");
const Dimensions = @import("Dimensions.zig");
const Dimension = @import("Dimensions.zig").Dimension;
/// Combines the scales of two quantity-like types, one dimension at a time.
///
/// `T1` and `T2` must each expose `dims` and `scales` declarations.
/// For every `Dimension`:
///   - if both types have a zero exponent, the result is `.none`;
///   - if only one type has a non-zero exponent, that type's scale is used;
///   - otherwise the scale with the smaller `getFactor()` wins, with ties
///     going to `T1`.
pub fn finerScales(comptime T1: type, comptime T2: type) Scales {
    const lhs_dims: Dimensions = T1.dims;
    const rhs_dims: Dimensions = T2.dims;
    const lhs_scales: Scales = T1.scales;
    const rhs_scales: Scales = T2.scales;

    comptime var result = Scales.initFill(.none);
    inline for (std.enums.values(Dimension)) |dim| {
        const lhs_scale = comptime lhs_scales.get(dim);
        const rhs_scale = comptime rhs_scales.get(dim);
        const lhs_unused = comptime lhs_dims.get(dim) == 0;
        const rhs_unused = comptime rhs_dims.get(dim) == 0;

        if (comptime lhs_unused and rhs_unused) {
            result.set(dim, .none);
        } else if (comptime lhs_unused) {
            result.set(dim, rhs_scale);
        } else if (comptime rhs_unused) {
            result.set(dim, lhs_scale);
        } else if (comptime lhs_scale.getFactor() > rhs_scale.getFactor()) {
            // The right-hand scale has the smaller factor, so it is the finer one.
            result.set(dim, rhs_scale);
        } else {
            result.set(dim, lhs_scale);
        }
    }
    comptime return result;
}
// ---------------------------------------------------------------------------
// RHS normalisation helpers
// ---------------------------------------------------------------------------
const Quantity = @import("Quantity.zig").Quantity;
/// Reports whether `T` is marked as a quantity type.
///
/// A type qualifies when it is a struct that declares `ISQUANTITY` and that
/// declaration evaluates to `true`. Non-struct types always yield `false`.
pub fn isScalarType(comptime T: type) bool {
    return switch (@typeInfo(T)) {
        // The declaration is only read once `@hasDecl` has confirmed it exists.
        .@"struct" => @hasDecl(T, "ISQUANTITY") and @field(T, "ISQUANTITY"),
        else => false,
    };
}
/// Computes the quantity type that a right-hand operand of type `RhsT` is
/// treated as.
///
/// Accepted `RhsT`:
///   - any type for which `isScalarType` is true: returned unchanged;
///   - `comptime_int`, `comptime_float`, or `ValueType` itself: mapped to
///     `Quantity(ValueType, N, .{}, .{})`, i.e. empty dimension and scale
///     arguments.
///
/// Every other type, including other integer and float types, is rejected
/// with a compile error.
pub fn rhsQuantityType(comptime ValueType: type, N: usize, comptime RhsT: type) type {
    if (comptime isScalarType(RhsT)) return RhsT;

    const is_bare_number = comptime (RhsT == comptime_int or
        RhsT == comptime_float or
        RhsT == ValueType);

    if (!is_bare_number) {
        @compileError(
            "rhs must be a Scalar, " ++ @typeName(ValueType) ++
                ", comptime_int, or comptime_float; got " ++ @typeName(RhsT),
        );
    }

    return Quantity(ValueType, N, .{}, .{});
}
/// Converts `rhs` into its normalised quantity form (see `rhsQuantityType`).
///
/// A value whose type satisfies `isScalarType` is returned as-is. Any other
/// accepted value is coerced to `BaseT` and splatted across the `data` field
/// of `Quantity(BaseT, N, .{}, .{})`.
pub inline fn toRhsQuantity(comptime BaseT: type, N: usize, rhs: anytype) rhsQuantityType(BaseT, N, @TypeOf(rhs)) {
    const RhsT = @TypeOf(rhs);
    if (comptime isScalarType(RhsT)) {
        return rhs;
    } else {
        const Plain = Quantity(BaseT, N, .{}, .{});
        const value: BaseT = rhs;
        return Plain{ .data = @splat(value) };
    }
}

View File

@ -1,15 +1,13 @@
const std = @import("std"); const std = @import("std");
pub const Vector = @import("Quantity.zig").Vector; pub const Tensor = @import("Tensor.zig").Tensor;
pub const Scalar = @import("Quantity.zig").Scalar;
pub const Dimensions = @import("Dimensions.zig"); pub const Dimensions = @import("Dimensions.zig");
pub const Scales = @import("Scales.zig"); pub const Scales = @import("Scales.zig");
pub const Base = @import("Base.zig"); pub const Base = @import("Base.zig");
test { test {
_ = @import("Quantity.zig"); _ = @import("Tensor.zig");
_ = @import("Dimensions.zig"); _ = @import("Dimensions.zig");
_ = @import("Scales.zig"); _ = @import("Scales.zig");
_ = @import("Base.zig"); _ = @import("Base.zig");
_ = @import("helper.zig");
} }