Fixed new Tensor to be everything (Scalar, Vector, Matrix and above)
This commit is contained in:
parent
934a40fe1a
commit
f37a196b15
59
src/Base.zig
59
src/Base.zig
@ -3,34 +3,39 @@ const std = @import("std");
|
||||
// Adjust these imports to match your actual file names
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Scales = @import("Scales.zig");
|
||||
const Scalar = @import("Quantity.zig").Scalar;
|
||||
const Tensor = @import("Tensor.zig").Tensor;
|
||||
|
||||
fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type {
|
||||
return struct {
|
||||
const dims = Dimensions.init(d);
|
||||
const scales = Scales.init(s);
|
||||
pub const dims = Dimensions.init(d);
|
||||
pub const scales = Scales.init(s);
|
||||
|
||||
/// Instantiates the constant into a specific numeric type.
|
||||
pub fn Of(comptime T: type) Scalar(T, d, s) {
|
||||
return .{ .data = @splat(@as(T, @floatCast(val))) };
|
||||
pub fn Of(comptime T: type) Tensor(T, d, s, &.{1}) {
|
||||
const casted_val: T = switch (@typeInfo(T)) {
|
||||
.float => @floatCast(val),
|
||||
.int => @intFromFloat(val),
|
||||
else => @compileError("Unsupported type for PhysicalConstant"),
|
||||
};
|
||||
return Tensor(T, d, s, &.{1}).splat(casted_val);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn BaseScalar(comptime d: Dimensions.ArgOpts) type {
|
||||
return struct {
|
||||
const dims = Dimensions.init(d);
|
||||
pub const dims = Dimensions.init(d);
|
||||
|
||||
/// Creates a Scalar of this dimension using default scales.
|
||||
/// Example: const V = Quantities.Velocity.Base(f32);
|
||||
/// Example: const V = Quantities.Velocity.Of(f32);
|
||||
pub fn Of(comptime T: type) type {
|
||||
return Scalar(T, d, .{});
|
||||
return Tensor(T, d, .{}, &.{1});
|
||||
}
|
||||
|
||||
/// Creates a Scalar of this dimension using custom scales.
|
||||
/// Example: const Kmh = Quantities.Velocity.Scaled(f32, Scales.init(.{ .L = .k, .T = .hour }));
|
||||
/// Example: const Kmh = Quantities.Velocity.Scaled(f32, .{ .L = .k, .T = .hour });
|
||||
pub fn Scaled(comptime T: type, comptime s: Scales.ArgOpts) type {
|
||||
return Scalar(T, d, s);
|
||||
return Tensor(T, d, s, &.{1});
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -107,7 +112,7 @@ pub const ElectricCapacitance = BaseScalar(.{ .T = 4, .L = -2, .M = -1, .I = 2 }
|
||||
pub const ElectricImpedance = ElectricResistance;
|
||||
pub const MagneticFlux = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .I = -1 });
|
||||
pub const MagneticDensity = BaseScalar(.{ .M = 1, .T = -2, .I = -1 });
|
||||
pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 }); // Fixed typo from MagneticStrengh
|
||||
pub const MagneticStrength = BaseScalar(.{ .L = -1, .I = 1 });
|
||||
pub const MagneticMoment = BaseScalar(.{ .L = 2, .I = 1 });
|
||||
|
||||
// ==========================================
|
||||
@ -140,7 +145,7 @@ pub const ThermalHeat = Energy;
|
||||
pub const ThermalWork = Energy;
|
||||
pub const ThermalCapacity = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 });
|
||||
pub const ThermalCapacityPerMass = BaseScalar(.{ .L = 2, .T = -2, .Tr = -1 });
|
||||
pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 }); // Fixed typo from ThermalluxDensity
|
||||
pub const ThermalFluxDensity = BaseScalar(.{ .M = 1, .T = -3 });
|
||||
pub const ThermalConductance = BaseScalar(.{ .M = 1, .L = 2, .T = -3, .Tr = -1 });
|
||||
pub const ThermalConductivity = BaseScalar(.{ .M = 1, .L = 1, .T = -3, .Tr = -1 });
|
||||
pub const ThermalResistance = BaseScalar(.{ .M = -1, .L = -2, .T = 3, .Tr = 1 });
|
||||
@ -152,20 +157,24 @@ pub const ThermalEntropy = BaseScalar(.{ .M = 1, .L = 2, .T = -2, .Tr = -1 });
|
||||
// ==========================================
|
||||
pub const Frequency = BaseScalar(.{ .T = -1 });
|
||||
pub const Viscosity = BaseScalar(.{ .M = 1, .L = -1, .T = -1 });
|
||||
pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 }); // Corrected from MT-2a
|
||||
pub const SurfaceTension = BaseScalar(.{ .M = 1, .T = -2 });
|
||||
|
||||
// ==========================================
|
||||
// Tests
|
||||
// ==========================================
|
||||
|
||||
test "BaseQuantities - Core dimensions instantiation" {
|
||||
// Basic types via generic wrappers
|
||||
const M = Meter.Of(f32);
|
||||
const distance = M.splat(100);
|
||||
try std.testing.expectEqual(100.0, distance.value());
|
||||
try std.testing.expectEqual(100.0, distance.data[0]);
|
||||
try std.testing.expectEqual(1, M.dims.get(.L));
|
||||
try std.testing.expectEqual(0, M.dims.get(.T));
|
||||
|
||||
// Test specific scale variants
|
||||
const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour });
|
||||
const speed = Kmh.splat(120);
|
||||
try std.testing.expectEqual(120.0, speed.value());
|
||||
try std.testing.expectEqual(120.0, speed.data[0]);
|
||||
try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L));
|
||||
try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T));
|
||||
}
|
||||
@ -176,12 +185,12 @@ test "BaseQuantities - Kinematics equations" {
|
||||
|
||||
// Velocity = Distance / Time
|
||||
const v = d.div(t);
|
||||
try std.testing.expectEqual(25.0, v.value());
|
||||
try std.testing.expectEqual(25.0, v.data[0]);
|
||||
try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
||||
|
||||
// Acceleration = Velocity / Time
|
||||
const a = v.div(t);
|
||||
try std.testing.expectEqual(12.5, a.value());
|
||||
try std.testing.expectEqual(12.5, a.data[0]);
|
||||
try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
||||
}
|
||||
|
||||
@ -193,13 +202,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
|
||||
|
||||
// Force = mass * acceleration
|
||||
const f = m.mul(a);
|
||||
try std.testing.expectEqual(98, f.value());
|
||||
try std.testing.expectEqual(98, f.data[0]);
|
||||
try std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
||||
|
||||
// Energy (Work) = Force * distance
|
||||
const distance = Meter.Of(f32).splat(5.0);
|
||||
const energy = f.mul(distance);
|
||||
try std.testing.expectEqual(490, energy.value());
|
||||
try std.testing.expectEqual(490, energy.data[0]);
|
||||
try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
||||
}
|
||||
|
||||
@ -209,26 +218,26 @@ test "BaseQuantities - Electric combinations" {
|
||||
|
||||
// Charge = Current * time
|
||||
const charge = current.mul(time);
|
||||
try std.testing.expectEqual(6.0, charge.value());
|
||||
try std.testing.expectEqual(6.0, charge.data[0]);
|
||||
try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
||||
}
|
||||
|
||||
test "Constants - Initialization and dimension checks" {
|
||||
// Speed of Light
|
||||
const c = Constants.SpeedOfLight.Of(f64);
|
||||
try std.testing.expectEqual(299792458.0, c.value());
|
||||
try std.testing.expectEqual(299792458.0, c.data[0]);
|
||||
try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L));
|
||||
try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T));
|
||||
|
||||
// Electron Mass (verifying scale as well)
|
||||
const me = Constants.ElectronMass.Of(f64);
|
||||
try std.testing.expectEqual(9.1093837139e-31, me.value());
|
||||
try std.testing.expectEqual(9.1093837139e-31, me.data[0]);
|
||||
try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M));
|
||||
try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg
|
||||
|
||||
// Boltzmann Constant (Complex derived dimensions)
|
||||
const kb = Constants.Boltzmann.Of(f64);
|
||||
try std.testing.expectEqual(1.380649e-23, kb.value());
|
||||
try std.testing.expectEqual(1.380649e-23, kb.data[0]);
|
||||
try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M));
|
||||
try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L));
|
||||
try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T));
|
||||
@ -237,7 +246,7 @@ test "Constants - Initialization and dimension checks" {
|
||||
|
||||
// Vacuum Permittivity
|
||||
const eps0 = Constants.VacuumPermittivity.Of(f64);
|
||||
try std.testing.expectEqual(8.8541878188e-12, eps0.value());
|
||||
try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]);
|
||||
try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M));
|
||||
try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L));
|
||||
try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T));
|
||||
@ -245,7 +254,7 @@ test "Constants - Initialization and dimension checks" {
|
||||
|
||||
// Fine Structure Constant (Dimensionless)
|
||||
const alpha = Constants.FineStructure.Of(f64);
|
||||
try std.testing.expectEqual(0.0072973525643, alpha.value());
|
||||
try std.testing.expectEqual(0.0072973525643, alpha.data[0]);
|
||||
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M));
|
||||
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L));
|
||||
}
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
const std = @import("std");
|
||||
const hlp = @import("helper.zig");
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = @import("Dimensions.zig").Dimension;
|
||||
|
||||
@ -99,7 +98,7 @@ data: std.EnumArray(Dimension, UnitScale),
|
||||
pub fn init(comptime init_val: ArgOpts) Self {
|
||||
comptime var s = Self{ .data = std.EnumArray(Dimension, UnitScale).initFill(.none) };
|
||||
inline for (std.meta.fields(@TypeOf(init_val))) |f| {
|
||||
if (comptime hlp.isInt(@TypeOf(@field(init_val, f.name))))
|
||||
if (comptime @typeInfo(@TypeOf(@field(init_val, f.name))) == .comptime_int)
|
||||
s.data.set(@field(Dimension, f.name), @enumFromInt(@field(init_val, f.name)))
|
||||
else
|
||||
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
const std = @import("std");
|
||||
const hlp = @import("helper.zig");
|
||||
const Scales = @import("Scales.zig");
|
||||
const UnitScale = Scales.UnitScale;
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = Dimensions.Dimension;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Comptime shape utilities
|
||||
// Comptime utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn shapeTotal(comptime shape: []const usize) usize {
|
||||
@ -34,7 +33,10 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha
|
||||
var out: [shape.len - 1]usize = undefined;
|
||||
var j: usize = 0;
|
||||
for (shape, 0..) |v, i| {
|
||||
if (i != axis) { out[j] = v; j += 1; }
|
||||
if (i != axis) {
|
||||
out[j] = v;
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -96,6 +98,32 @@ pub fn insertAxis(
|
||||
return out;
|
||||
}
|
||||
|
||||
fn isInt(comptime T: type) bool {
|
||||
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
||||
}
|
||||
|
||||
fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||
const d1: Dimensions = T1.dims;
|
||||
const d2: Dimensions = T2.dims;
|
||||
const s1: Scales = T1.scales;
|
||||
const s2: Scales = T2.scales;
|
||||
comptime var out = Scales.initFill(.none);
|
||||
inline for (std.enums.values(Dimension)) |dim| {
|
||||
const scale1 = comptime s1.get(dim);
|
||||
const scale2 = comptime s2.get(dim);
|
||||
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
||||
.none
|
||||
else if (comptime d1.get(dim) == 0)
|
||||
scale2
|
||||
else if (comptime d2.get(dim) == 0)
|
||||
scale1
|
||||
else if (comptime scale1.getFactor() > scale2.getFactor())
|
||||
scale2
|
||||
else
|
||||
scale1);
|
||||
}
|
||||
comptime return out;
|
||||
}
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// File-scope RHS normalisation helpers
|
||||
//
|
||||
@ -104,14 +132,18 @@ pub fn insertAxis(
|
||||
// Actual Tensor types are passed through unchanged.
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn isTensor(comptime Rhs: type) bool {
|
||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||
}
|
||||
|
||||
fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||
if (@hasDecl(Rhs, "ISTENSOR")) return Rhs;
|
||||
if (comptime isTensor(Rhs)) return Rhs;
|
||||
return Tensor(T, .{}, .{}, &.{1});
|
||||
}
|
||||
|
||||
fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||
const Rhs = @TypeOf(r);
|
||||
if (comptime @hasDecl(Rhs, "ISTENSOR")) return r;
|
||||
if (comptime isTensor(Rhs)) return r;
|
||||
const scalar: T = switch (comptime @typeInfo(Rhs)) {
|
||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||
.float => @as(T, @floatFromInt(r)),
|
||||
@ -134,46 +166,59 @@ fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||
return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tensor — unified dimensioned ND type.
|
||||
//
|
||||
// T : element numeric type (f32, f64, i32, i128, …)
|
||||
// d_opt : SI dimension exponents
|
||||
// s_opt : unit scales
|
||||
// shape_ : compile-time shape
|
||||
// &.{1} → scalar
|
||||
// &.{3} → 3-vector
|
||||
// &.{4, 4} → 4×4 matrix
|
||||
// &.{3, 3, 3} → 3D field
|
||||
//
|
||||
// Storage: flat @Vector(total, T) where total = product(shape_).
|
||||
// All arithmetic operates on the flat vector directly → SIMD wherever possible.
|
||||
//
|
||||
// Shape-related comptime constants exposed on every Tensor type:
|
||||
// dims : Dimensions — SI exponent struct
|
||||
// scales : Scales — unit scale struct
|
||||
// shape : []const usize
|
||||
// rank : usize = shape.len
|
||||
// total : usize = product(shape)
|
||||
// strides_arr : [rank]usize — row-major strides
|
||||
//
|
||||
// Index helper:
|
||||
// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost)
|
||||
//
|
||||
// GPU readiness:
|
||||
// tensor.asSlice() → []T (zero-copy pointer to the flat @Vector storage)
|
||||
//
|
||||
// Contraction (replaces dot / cross / matmul):
|
||||
// a.contract(b, axis_a, axis_b)
|
||||
// For rank-1 × rank-1 this is the dot product.
|
||||
// For rank-2 × rank-2 with axis_a=1, axis_b=0 this is matrix multiply.
|
||||
//
|
||||
// Removed from Quantity:
|
||||
// Scalar / Vector aliases, Vec3 / ScalarType, .value(), .vec(), .vec3(),
|
||||
// dot(), cross(), mulScalar(), divScalar(), eqScalar() and friends.
|
||||
// Use Tensor(..., &.{1}), .data[0], mul(), div(), eq() respectively.
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||
if (n == 0) return;
|
||||
var val = n;
|
||||
if (val < 0) {
|
||||
try writer.writeAll("\u{207B}");
|
||||
val = -val;
|
||||
}
|
||||
var buf: [12]u8 = undefined;
|
||||
const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
|
||||
for (str) |c| {
|
||||
const s = switch (c) {
|
||||
'0' => "\u{2070}",
|
||||
'1' => "\u{00B9}",
|
||||
'2' => "\u{00B2}",
|
||||
'3' => "\u{00B3}",
|
||||
'4' => "\u{2074}",
|
||||
'5' => "\u{2075}",
|
||||
'6' => "\u{2076}",
|
||||
'7' => "\u{2077}",
|
||||
'8' => "\u{2078}",
|
||||
'9' => "\u{2079}",
|
||||
else => unreachable,
|
||||
};
|
||||
try writer.writeAll(s);
|
||||
}
|
||||
}
|
||||
|
||||
/// ─────────────────────────────────────────────────────────────────────────────
|
||||
/// Tensor — unified dimensioned ND type.
|
||||
///
|
||||
/// T : element numeric type (f32, f64, i32, i128, …)
|
||||
/// d_opt : SI dimension exponents
|
||||
/// s_opt : unit scales
|
||||
/// shape_ : compile-time shape
|
||||
/// &.{1} → scalar
|
||||
/// &.{3} → 3-vector
|
||||
/// &.{4, 4} → 4×4 matrix
|
||||
/// &.{3, 3, 3} → 3D field
|
||||
///
|
||||
/// Storage: flat @Vector(total, T) where total = product(shape_).
|
||||
/// All arithmetic operates on the flat vector directly → SIMD wherever possible.
|
||||
///
|
||||
/// Shape-related comptime constants exposed on every Tensor type:
|
||||
/// dims : Dimensions — SI exponent struct
|
||||
/// scales : Scales — unit scale struct
|
||||
/// shape : []const usize
|
||||
/// rank : usize = shape.len
|
||||
/// total : usize = product(shape)
|
||||
/// strides_arr : [rank]usize — row-major strides
|
||||
///
|
||||
/// Index helper:
|
||||
/// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost)
|
||||
/// ─────────────────────────────────────────────────────────────────────────────
|
||||
pub fn Tensor(
|
||||
comptime T: type,
|
||||
comptime d_opt: Dimensions.ArgOpts,
|
||||
@ -196,12 +241,12 @@ pub fn Tensor(
|
||||
|
||||
const Self = @This();
|
||||
|
||||
pub const ValueType : type = T;
|
||||
pub const dims : Dimensions = Dimensions.init(d_opt);
|
||||
pub const scales : Scales = Scales.init(s_opt);
|
||||
pub const shape : []const usize = shape_;
|
||||
pub const rank : usize = shape_.len;
|
||||
pub const total : usize = _total;
|
||||
pub const ValueType: type = T;
|
||||
pub const dims: Dimensions = Dimensions.init(d_opt);
|
||||
pub const scales: Scales = Scales.init(s_opt);
|
||||
pub const shape: []const usize = shape_;
|
||||
pub const rank: usize = shape_.len;
|
||||
pub const total: usize = _total;
|
||||
pub const strides_arr: [shape_.len]usize = _strides;
|
||||
pub const ISTENSOR = true;
|
||||
|
||||
@ -211,7 +256,7 @@ pub fn Tensor(
|
||||
|
||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||
/// Usage: Tensor.idx(.{row, col})
|
||||
pub fn idx(comptime coords: [rank]usize) usize {
|
||||
pub inline fn idx(comptime coords: [rank]usize) usize {
|
||||
comptime {
|
||||
var flat: usize = 0;
|
||||
for (0..rank) |i| {
|
||||
@ -247,8 +292,12 @@ pub fn Tensor(
|
||||
// Internal: RHS normalisation
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); }
|
||||
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { return toRhsTensor(T, r); }
|
||||
inline fn RhsT(comptime Rhs: type) type {
|
||||
return RhsTensorType(T, Rhs);
|
||||
}
|
||||
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) {
|
||||
return toRhsTensor(T, r);
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Internal: scalar broadcast (shape {1} → full Vec)
|
||||
@ -270,7 +319,7 @@ pub fn Tensor(
|
||||
pub inline fn add(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
@ -280,21 +329,21 @@ pub fn Tensor(
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1).");
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
return .{ .data = if (comptime hlp.isInt(T)) l +| rr else l + rr };
|
||||
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
|
||||
}
|
||||
|
||||
/// Element-wise subtract. Dimensions must match; scales resolve to finer.
|
||||
pub inline fn sub(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
@ -304,14 +353,14 @@ pub fn Tensor(
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in sub.");
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
return .{ .data = if (comptime hlp.isInt(T)) l -| rr else l - rr };
|
||||
return .{ .data = if (comptime isInt(T)) l -| rr else l - rr };
|
||||
}
|
||||
|
||||
/// Element-wise multiply. Dimension exponents summed.
|
||||
@ -319,7 +368,7 @@ pub fn Tensor(
|
||||
pub inline fn mul(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
@ -327,12 +376,12 @@ pub fn Tensor(
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in mul.");
|
||||
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||
return .{ .data = if (comptime hlp.isInt(T)) l *| rr else l * rr };
|
||||
return .{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
||||
}
|
||||
|
||||
/// Element-wise divide. Dimension exponents subtracted.
|
||||
@ -340,7 +389,7 @@ pub fn Tensor(
|
||||
pub inline fn div(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||
hlp.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
@ -348,12 +397,12 @@ pub fn Tensor(
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in div.");
|
||||
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||
if (comptime hlp.isInt(T)) {
|
||||
if (comptime isInt(T)) {
|
||||
var result: Vec = undefined;
|
||||
inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
|
||||
return .{ .data = result };
|
||||
@ -378,7 +427,7 @@ pub fn Tensor(
|
||||
scales.argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
if (comptime hlp.isInt(T)) {
|
||||
if (comptime isInt(T)) {
|
||||
var result: Vec = undefined;
|
||||
inline for (0..total) |i|
|
||||
result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T);
|
||||
@ -506,10 +555,10 @@ pub fn Tensor(
|
||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
const TargetType = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), hlp.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
@ -604,9 +653,9 @@ pub fn Tensor(
|
||||
comptime axis_b: usize,
|
||||
) blk: {
|
||||
const OT = @TypeOf(other);
|
||||
comptime std.debug.assert(axis_a < rank);
|
||||
comptime std.debug.assert(axis_b < OT.rank);
|
||||
comptime std.debug.assert(shape_[axis_a] == OT.shape[axis_b]);
|
||||
std.debug.assert(axis_a < rank);
|
||||
std.debug.assert(axis_b < OT.rank);
|
||||
std.debug.assert(shape_[axis_a] == OT.shape[axis_b]);
|
||||
// Contracted-away free axes; empty joint → scalar shape {1}
|
||||
const sa = shapeRemoveAxis(shape_, axis_a);
|
||||
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
||||
@ -615,7 +664,7 @@ pub fn Tensor(
|
||||
break :blk Tensor(
|
||||
T,
|
||||
dims.add(OT.dims).argsOpt(),
|
||||
hlp.finerScales(Self, OT).argsOpt(),
|
||||
finerScales(Self, OT).argsOpt(),
|
||||
rs,
|
||||
);
|
||||
} {
|
||||
@ -630,13 +679,13 @@ pub fn Tensor(
|
||||
const ResultType = Tensor(
|
||||
T,
|
||||
dims.add(OT.dims).argsOpt(),
|
||||
hlp.finerScales(Self, OT).argsOpt(),
|
||||
finerScales(Self, OT).argsOpt(),
|
||||
rs,
|
||||
);
|
||||
|
||||
// Normalise scales before accumulation
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), shape_);
|
||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), hlp.finerScales(Self, OT).argsOpt(), OT.shape);
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
||||
|
||||
@ -661,7 +710,7 @@ pub fn Tensor(
|
||||
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
|
||||
const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
||||
|
||||
if (comptime hlp.isInt(T))
|
||||
if (comptime isInt(T))
|
||||
acc +|= a_data[a_flat] *| b_data[b_flat]
|
||||
else
|
||||
acc += a_data[a_flat] * b_data[b_flat];
|
||||
@ -751,7 +800,7 @@ pub fn Tensor(
|
||||
else
|
||||
try writer.print("{s}{s}", .{ uscale.str(), bu.unit() });
|
||||
|
||||
if (v != 1) try hlp.printSuperscript(writer, v);
|
||||
if (v != 1) try printSuperscript(writer, v);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1220,8 +1269,8 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
|
||||
|
||||
test "Vector contract — matrix multiply (rank-2 × rank-2)" {
|
||||
// 2×3 matrix multiplied by 3×2 matrix → 2×2 result
|
||||
const A = Tensor(f32, .{}, .{}, &.{2, 3});
|
||||
const B = Tensor(f32, .{}, .{}, &.{3, 2});
|
||||
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
|
||||
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
|
||||
|
||||
// A = [[1, 2, 3],
|
||||
// [4, 5, 6]]
|
||||
@ -1237,10 +1286,10 @@ test "Vector contract — matrix multiply (rank-2 × rank-2)" {
|
||||
// C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
|
||||
// C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
|
||||
const c = a.contract(b, 1, 0);
|
||||
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 0})]);
|
||||
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{0, 1})]);
|
||||
try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 0})]);
|
||||
try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{2,2}).idx(.{1, 1})]);
|
||||
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
||||
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
|
||||
try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]);
|
||||
try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]);
|
||||
}
|
||||
|
||||
test "Vector Abs, Pow, Sqrt and Product" {
|
||||
@ -1321,12 +1370,12 @@ test "Vector eq broadcast on dimensionless" {
|
||||
}
|
||||
|
||||
test "Tensor idx helper and matrix access" {
|
||||
const Mat3x3 = Tensor(f32, .{}, .{}, &.{3, 3});
|
||||
const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
|
||||
// Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3
|
||||
var m: Mat3x3 = Mat3x3.zero;
|
||||
m.data[Mat3x3.idx(.{0, 0})] = 1.0;
|
||||
m.data[Mat3x3.idx(.{1, 1})] = 2.0;
|
||||
m.data[Mat3x3.idx(.{2, 2})] = 3.0;
|
||||
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
||||
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
||||
m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
|
||||
|
||||
try std.testing.expectEqual(1.0, m.data[0]); // [0][0]
|
||||
try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4)
|
||||
@ -1336,8 +1385,8 @@ test "Tensor idx helper and matrix access" {
|
||||
|
||||
test "Tensor strides_arr correctness" {
|
||||
const T1 = Tensor(f32, .{}, .{}, &.{3});
|
||||
const T2 = Tensor(f32, .{}, .{}, &.{3, 4});
|
||||
const T3 = Tensor(f32, .{}, .{}, &.{2, 3, 4});
|
||||
const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 });
|
||||
const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 });
|
||||
|
||||
try std.testing.expectEqual(1, T1.strides_arr[0]);
|
||||
try std.testing.expectEqual(4, T2.strides_arr[0]);
|
||||
@ -1,97 +0,0 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub fn isInt(comptime T: type) bool {
|
||||
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
||||
}
|
||||
|
||||
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||
if (n == 0) return;
|
||||
var val = n;
|
||||
if (val < 0) {
|
||||
try writer.writeAll("\u{207B}");
|
||||
val = -val;
|
||||
}
|
||||
var buf: [12]u8 = undefined;
|
||||
const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
|
||||
for (str) |c| {
|
||||
const s = switch (c) {
|
||||
'0' => "\u{2070}",
|
||||
'1' => "\u{00B9}",
|
||||
'2' => "\u{00B2}",
|
||||
'3' => "\u{00B3}",
|
||||
'4' => "\u{2074}",
|
||||
'5' => "\u{2075}",
|
||||
'6' => "\u{2076}",
|
||||
'7' => "\u{2077}",
|
||||
'8' => "\u{2078}",
|
||||
'9' => "\u{2079}",
|
||||
else => unreachable,
|
||||
};
|
||||
try writer.writeAll(s);
|
||||
}
|
||||
}
|
||||
|
||||
const Scales = @import("Scales.zig");
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = @import("Dimensions.zig").Dimension;
|
||||
|
||||
pub fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||
const d1: Dimensions = T1.dims;
|
||||
const d2: Dimensions = T2.dims;
|
||||
const s1: Scales = T1.scales;
|
||||
const s2: Scales = T2.scales;
|
||||
comptime var out = Scales.initFill(.none);
|
||||
inline for (std.enums.values(Dimension)) |dim| {
|
||||
const scale1 = comptime s1.get(dim);
|
||||
const scale2 = comptime s2.get(dim);
|
||||
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
||||
.none
|
||||
else if (comptime d1.get(dim) == 0)
|
||||
scale2
|
||||
else if (comptime d2.get(dim) == 0)
|
||||
scale1
|
||||
else if (comptime scale1.getFactor() > scale2.getFactor())
|
||||
scale2
|
||||
else
|
||||
scale1);
|
||||
}
|
||||
comptime return out;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RHS normalisation helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const Quantity = @import("Quantity.zig").Quantity;
|
||||
|
||||
/// Returns true if `T` is a `Scalar_` type (has `dims`, `scales`, and `value`).
|
||||
pub fn isScalarType(comptime T: type) bool {
|
||||
return @typeInfo(T) == .@"struct" and
|
||||
@hasDecl(T, "ISQUANTITY") and
|
||||
@field(T, "ISQUANTITY");
|
||||
}
|
||||
|
||||
/// Resolve the Scalar type that `rhs` will be treated as.
|
||||
///
|
||||
/// Accepted rhs types:
|
||||
/// - Any `Scalar_` type → returned as-is
|
||||
/// - `comptime_int` / `comptime_float` → dimensionless `Scalar_(BaseT, {}, {})`
|
||||
/// - `BaseT` (the scalar's value type) → dimensionless `Scalar_(BaseT, {}, {})`
|
||||
///
|
||||
/// Everything else is a compile error, including other int/float types.
|
||||
pub fn rhsQuantityType(comptime ValueType: type, N: usize, comptime RhsT: type) type {
|
||||
if (comptime isScalarType(RhsT)) return RhsT;
|
||||
if (comptime RhsT == comptime_int or RhsT == comptime_float or RhsT == ValueType)
|
||||
return Quantity(ValueType, N, .{}, .{});
|
||||
@compileError(
|
||||
"rhs must be a Scalar, " ++ @typeName(ValueType) ++
|
||||
", comptime_int, or comptime_float; got " ++ @typeName(RhsT),
|
||||
);
|
||||
}
|
||||
|
||||
/// Convert `rhs` to its normalised Scalar form (see `rhsScalarType`).
|
||||
pub inline fn toRhsQuantity(comptime BaseT: type, N: usize, rhs: anytype) rhsQuantityType(BaseT, N, @TypeOf(rhs)) {
|
||||
if (comptime isScalarType(@TypeOf(rhs))) return rhs;
|
||||
const DimLess = Quantity(BaseT, N, .{}, .{});
|
||||
return DimLess{ .data = @splat(@as(BaseT, rhs)) };
|
||||
}
|
||||
@ -1,15 +1,13 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const Vector = @import("Quantity.zig").Vector;
|
||||
pub const Scalar = @import("Quantity.zig").Scalar;
|
||||
pub const Tensor = @import("Tensor.zig").Tensor;
|
||||
pub const Dimensions = @import("Dimensions.zig");
|
||||
pub const Scales = @import("Scales.zig");
|
||||
pub const Base = @import("Base.zig");
|
||||
|
||||
test {
|
||||
_ = @import("Quantity.zig");
|
||||
_ = @import("Tensor.zig");
|
||||
_ = @import("Dimensions.zig");
|
||||
_ = @import("Scales.zig");
|
||||
_ = @import("Base.zig");
|
||||
_ = @import("helper.zig");
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user