Now pass all test with new *const way
I am not quite sure about it yet, but it is faster sooo idk. Let's see long term
This commit is contained in:
parent
26ff02c50f
commit
8816a65518
46
src/Base.zig
46
src/Base.zig
@ -11,7 +11,7 @@ fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime
|
|||||||
pub const scales = Scales.init(s);
|
pub const scales = Scales.init(s);
|
||||||
|
|
||||||
/// Instantiates the constant into a specific numeric type.
|
/// Instantiates the constant into a specific numeric type.
|
||||||
pub fn Of(comptime T: type) Tensor(T, d, s, &.{1}) {
|
pub fn Of(comptime T: type) *const Tensor(T, d, s, &.{1}) {
|
||||||
const casted_val: T = switch (@typeInfo(T)) {
|
const casted_val: T = switch (@typeInfo(T)) {
|
||||||
.float => @floatCast(val),
|
.float => @floatCast(val),
|
||||||
.int => @intFromFloat(val),
|
.int => @intFromFloat(val),
|
||||||
@ -175,8 +175,8 @@ test "BaseQuantities - Core dimensions instantiation" {
|
|||||||
const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour });
|
const Kmh = Speed.Scaled(f32, .{ .L = .k, .T = .hour });
|
||||||
const speed = Kmh.splat(120);
|
const speed = Kmh.splat(120);
|
||||||
try std.testing.expectEqual(120.0, speed.data[0]);
|
try std.testing.expectEqual(120.0, speed.data[0]);
|
||||||
try std.testing.expectEqual(.k, @TypeOf(speed).scales.get(.L));
|
try std.testing.expectEqual(.k, @TypeOf(speed.*).scales.get(.L));
|
||||||
try std.testing.expectEqual(.hour, @TypeOf(speed).scales.get(.T));
|
try std.testing.expectEqual(.hour, @TypeOf(speed.*).scales.get(.T));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "BaseQuantities - Kinematics equations" {
|
test "BaseQuantities - Kinematics equations" {
|
||||||
@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" {
|
|||||||
// Velocity = Distance / Time
|
// Velocity = Distance / Time
|
||||||
const v = d.div(t);
|
const v = d.div(t);
|
||||||
try std.testing.expectEqual(25.0, v.data[0]);
|
try std.testing.expectEqual(25.0, v.data[0]);
|
||||||
try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v.*).dims));
|
||||||
|
|
||||||
// Acceleration = Velocity / Time
|
// Acceleration = Velocity / Time
|
||||||
const a = v.div(t);
|
const a = v.div(t);
|
||||||
try std.testing.expectEqual(12.5, a.data[0]);
|
try std.testing.expectEqual(12.5, a.data[0]);
|
||||||
try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a.*).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "BaseQuantities - Dynamics (Force and Work)" {
|
test "BaseQuantities - Dynamics (Force and Work)" {
|
||||||
@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
|
|||||||
// Force = mass * acceleration
|
// Force = mass * acceleration
|
||||||
const f = m.mul(a);
|
const f = m.mul(a);
|
||||||
try std.testing.expectEqual(98, f.data[0]);
|
try std.testing.expectEqual(98, f.data[0]);
|
||||||
try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
try comptime std.testing.expect(Force.dims.eql(@TypeOf(f.*).dims));
|
||||||
|
|
||||||
// Energy (Work) = Force * distance
|
// Energy (Work) = Force * distance
|
||||||
const distance = Meter.Of(f32).splat(5.0);
|
const distance = Meter.Of(f32).splat(5.0);
|
||||||
const energy = f.mul(distance);
|
const energy = f.mul(distance);
|
||||||
try std.testing.expectEqual(490, energy.data[0]);
|
try std.testing.expectEqual(490, energy.data[0]);
|
||||||
try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy.*).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "BaseQuantities - Electric combinations" {
|
test "BaseQuantities - Electric combinations" {
|
||||||
@ -219,42 +219,42 @@ test "BaseQuantities - Electric combinations" {
|
|||||||
// Charge = Current * time
|
// Charge = Current * time
|
||||||
const charge = current.mul(time);
|
const charge = current.mul(time);
|
||||||
try std.testing.expectEqual(6.0, charge.data[0]);
|
try std.testing.expectEqual(6.0, charge.data[0]);
|
||||||
try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge.*).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Constants - Initialization and dimension checks" {
|
test "Constants - Initialization and dimension checks" {
|
||||||
// Speed of Light
|
// Speed of Light
|
||||||
const c = Constants.SpeedOfLight.Of(f64);
|
const c = Constants.SpeedOfLight.Of(f64);
|
||||||
try std.testing.expectEqual(299792458.0, c.data[0]);
|
try std.testing.expectEqual(299792458.0, c.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(c).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(c.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(-1, @TypeOf(c).dims.get(.T));
|
try std.testing.expectEqual(-1, @TypeOf(c.*).dims.get(.T));
|
||||||
|
|
||||||
// Electron Mass (verifying scale as well)
|
// Electron Mass (verifying scale as well)
|
||||||
const me = Constants.ElectronMass.Of(f64);
|
const me = Constants.ElectronMass.Of(f64);
|
||||||
try std.testing.expectEqual(9.1093837139e-31, me.data[0]);
|
try std.testing.expectEqual(9.1093837139e-31, me.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(me).dims.get(.M));
|
try std.testing.expectEqual(1, @TypeOf(me.*).dims.get(.M));
|
||||||
try std.testing.expectEqual(.k, @TypeOf(me).scales.get(.M)); // Should be scaled to kg
|
try std.testing.expectEqual(.k, @TypeOf(me.*).scales.get(.M)); // Should be scaled to kg
|
||||||
|
|
||||||
// Boltzmann Constant (Complex derived dimensions)
|
// Boltzmann Constant (Complex derived dimensions)
|
||||||
const kb = Constants.Boltzmann.Of(f64);
|
const kb = Constants.Boltzmann.Of(f64);
|
||||||
try std.testing.expectEqual(1.380649e-23, kb.data[0]);
|
try std.testing.expectEqual(1.380649e-23, kb.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(kb).dims.get(.M));
|
try std.testing.expectEqual(1, @TypeOf(kb.*).dims.get(.M));
|
||||||
try std.testing.expectEqual(2, @TypeOf(kb).dims.get(.L));
|
try std.testing.expectEqual(2, @TypeOf(kb.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(-2, @TypeOf(kb).dims.get(.T));
|
try std.testing.expectEqual(-2, @TypeOf(kb.*).dims.get(.T));
|
||||||
try std.testing.expectEqual(-1, @TypeOf(kb).dims.get(.Tp));
|
try std.testing.expectEqual(-1, @TypeOf(kb.*).dims.get(.Tp));
|
||||||
try std.testing.expectEqual(.k, @TypeOf(kb).scales.get(.M));
|
try std.testing.expectEqual(.k, @TypeOf(kb.*).scales.get(.M));
|
||||||
|
|
||||||
// Vacuum Permittivity
|
// Vacuum Permittivity
|
||||||
const eps0 = Constants.VacuumPermittivity.Of(f64);
|
const eps0 = Constants.VacuumPermittivity.Of(f64);
|
||||||
try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]);
|
try std.testing.expectEqual(8.8541878188e-12, eps0.data[0]);
|
||||||
try std.testing.expectEqual(-1, @TypeOf(eps0).dims.get(.M));
|
try std.testing.expectEqual(-1, @TypeOf(eps0.*).dims.get(.M));
|
||||||
try std.testing.expectEqual(-3, @TypeOf(eps0).dims.get(.L));
|
try std.testing.expectEqual(-3, @TypeOf(eps0.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(4, @TypeOf(eps0).dims.get(.T));
|
try std.testing.expectEqual(4, @TypeOf(eps0.*).dims.get(.T));
|
||||||
try std.testing.expectEqual(2, @TypeOf(eps0).dims.get(.I));
|
try std.testing.expectEqual(2, @TypeOf(eps0.*).dims.get(.I));
|
||||||
|
|
||||||
// Fine Structure Constant (Dimensionless)
|
// Fine Structure Constant (Dimensionless)
|
||||||
const alpha = Constants.FineStructure.Of(f64);
|
const alpha = Constants.FineStructure.Of(f64);
|
||||||
try std.testing.expectEqual(0.0072973525643, alpha.data[0]);
|
try std.testing.expectEqual(0.0072973525643, alpha.data[0]);
|
||||||
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.M));
|
try std.testing.expectEqual(0, @TypeOf(alpha.*).dims.get(.M));
|
||||||
try std.testing.expectEqual(0, @TypeOf(alpha).dims.get(.L));
|
try std.testing.expectEqual(0, @TypeOf(alpha.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|||||||
150
src/Tensor.zig
150
src/Tensor.zig
@ -139,8 +139,8 @@ inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
|||||||
|
|
||||||
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
||||||
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
||||||
inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, @TypeOf(r.*)) {
|
inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
||||||
const Rhs = @TypeOf(r.*);
|
const Rhs = if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r);
|
||||||
if (comptime isTensor(Rhs)) return r;
|
if (comptime isTensor(Rhs)) return r;
|
||||||
const scalar: T = switch (@typeInfo(Rhs)) {
|
const scalar: T = switch (@typeInfo(Rhs)) {
|
||||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||||
@ -161,7 +161,7 @@ inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, @Typ
|
|||||||
},
|
},
|
||||||
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
||||||
};
|
};
|
||||||
return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
return &Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||||
@ -237,12 +237,12 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Broadcast a single value across all elements.
|
/// Broadcast a single value across all elements.
|
||||||
pub inline fn splat(v: T) Self {
|
pub inline fn splat(v: T) *const Self {
|
||||||
return .{ .data = @splat(v) };
|
return &.{ .data = @splat(v) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const zero: Self = splat(0);
|
pub const zero: Self = splat(0).*;
|
||||||
pub const one: Self = splat(1);
|
pub const one: Self = splat(1).*;
|
||||||
|
|
||||||
/// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping.
|
/// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping.
|
||||||
pub inline fn asSlice(self: *Self) []T {
|
pub inline fn asSlice(self: *Self) []T {
|
||||||
@ -252,11 +252,12 @@ pub fn Tensor(
|
|||||||
inline fn RhsT(comptime Rhs: type) type {
|
inline fn RhsT(comptime Rhs: type) type {
|
||||||
return RhsTensorType(T, Rhs);
|
return RhsTensorType(T, Rhs);
|
||||||
}
|
}
|
||||||
inline fn rhs(r: anytype) *const RhsT(@TypeOf(r.*)) {
|
|
||||||
|
inline fn rhs(r: anytype) *const RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
||||||
return toRhsTensor(T, r);
|
return toRhsTensor(T, r);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
|
inline fn broadcastToVec(comptime RhsType: type, r: *const RhsType) Vec {
|
||||||
return if (comptime RhsType.total == 1 and total > 1)
|
return if (comptime RhsType.total == 1 and total > 1)
|
||||||
@splat(r.data[0])
|
@splat(r.data[0])
|
||||||
else
|
else
|
||||||
@ -268,7 +269,7 @@ pub fn Tensor(
|
|||||||
pub inline fn add(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn add(self: *const Self, r: anytype) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_t = rhs(r);
|
const rhs_t = rhs(r);
|
||||||
@ -295,7 +296,7 @@ pub fn Tensor(
|
|||||||
pub inline fn sub(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn sub(self: *const Self, r: anytype) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_t = rhs(r);
|
const rhs_t = rhs(r);
|
||||||
@ -321,8 +322,8 @@ pub fn Tensor(
|
|||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn mul(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn mul(self: *const Self, r: anytype) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
|
dims.add(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
@ -334,7 +335,7 @@ pub fn Tensor(
|
|||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base.*);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,8 +343,8 @@ pub fn Tensor(
|
|||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn div(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn div(self: *const Self, r: anytype) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
|
dims.sub(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
@ -355,7 +356,7 @@ pub fn Tensor(
|
|||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base.*);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
if (comptime isInt(T)) {
|
if (comptime isInt(T)) {
|
||||||
return &.{ .data = @divTrunc(l, rr) };
|
return &.{ .data = @divTrunc(l, rr) };
|
||||||
} else {
|
} else {
|
||||||
@ -369,7 +370,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Raise every element to a comptime integer exponent.
|
/// Raise every element to a comptime integer exponent.
|
||||||
pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor(
|
pub inline fn pow(self: Self, comptime exp: comptime_int) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.scale(exp).argsOpt(),
|
dims.scale(exp).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -399,7 +400,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
pub inline fn sqrt(self: Self) Tensor(
|
pub inline fn sqrt(self: Self) *const Tensor(
|
||||||
T,
|
T,
|
||||||
dims.div(2).argsOpt(),
|
dims.div(2).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -422,7 +423,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Negate every element.
|
/// Negate every element.
|
||||||
pub inline fn negate(self: Self) Self {
|
pub inline fn negate(self: Self) *const Self {
|
||||||
return &.{ .data = -self.data };
|
return &.{ .data = -self.data };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,7 +527,7 @@ pub fn Tensor(
|
|||||||
const rr: Vec = blk: {
|
const rr: Vec = blk: {
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn.*);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
return .{ .l = l, .r = rr };
|
return .{ .l = l, .r = rr };
|
||||||
}
|
}
|
||||||
@ -598,7 +599,7 @@ pub fn Tensor(
|
|||||||
comptime axis_a: usize,
|
comptime axis_a: usize,
|
||||||
comptime axis_b: usize,
|
comptime axis_b: usize,
|
||||||
) blk: {
|
) blk: {
|
||||||
const OT = @TypeOf(other);
|
const OT = @TypeOf(other.*);
|
||||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||||
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
||||||
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
||||||
@ -614,7 +615,7 @@ pub fn Tensor(
|
|||||||
rs,
|
rs,
|
||||||
);
|
);
|
||||||
} {
|
} {
|
||||||
const OT = @TypeOf(other);
|
const OT = @TypeOf(other.*);
|
||||||
const k: usize = comptime shape_[axis_a]; // contraction dimension
|
const k: usize = comptime shape_[axis_a]; // contraction dimension
|
||||||
|
|
||||||
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
||||||
@ -638,7 +639,7 @@ pub fn Tensor(
|
|||||||
// FAST PATH: Dot Product
|
// FAST PATH: Dot Product
|
||||||
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
||||||
if (comptime !isInt(T)) {
|
if (comptime !isInt(T)) {
|
||||||
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
return &.{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
||||||
} else {
|
} else {
|
||||||
// For integers, we do a vectorized saturating multiply,
|
// For integers, we do a vectorized saturating multiply,
|
||||||
// then convert to an array to do a saturating sum
|
// then convert to an array to do a saturating sum
|
||||||
@ -772,7 +773,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatNumber(
|
pub fn formatNumber(
|
||||||
self: *const Self,
|
self: Self,
|
||||||
writer: *std.Io.Writer,
|
writer: *std.Io.Writer,
|
||||||
options: std.fmt.Number,
|
options: std.fmt.Number,
|
||||||
) !void {
|
) !void {
|
||||||
@ -876,7 +877,7 @@ test "Scalar Add" {
|
|||||||
const distance2 = Meter.splat(20);
|
const distance2 = Meter.splat(20);
|
||||||
const added = distance.add(distance2);
|
const added = distance.add(distance2);
|
||||||
try std.testing.expectEqual(30, added.data[0]);
|
try std.testing.expectEqual(30, added.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(added.*).dims.get(.L));
|
||||||
|
|
||||||
const distance3 = KiloMeter.splat(2);
|
const distance3 = KiloMeter.splat(2);
|
||||||
const added2 = distance.add(distance3);
|
const added2 = distance.add(distance3);
|
||||||
@ -915,13 +916,13 @@ test "Scalar MulBy" {
|
|||||||
const t = Second.splat(4);
|
const t = Second.splat(4);
|
||||||
const at = d.mul(t);
|
const at = d.mul(t);
|
||||||
try std.testing.expectEqual(12, at.data[0]);
|
try std.testing.expectEqual(12, at.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(at).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(at.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(1, @TypeOf(at).dims.get(.T));
|
try std.testing.expectEqual(1, @TypeOf(at.*).dims.get(.T));
|
||||||
|
|
||||||
const d2 = Meter.splat(5);
|
const d2 = Meter.splat(5);
|
||||||
const area = d.mul(d2);
|
const area = d.mul(d2);
|
||||||
try std.testing.expectEqual(15, area.data[0]);
|
try std.testing.expectEqual(15, area.data[0]);
|
||||||
try std.testing.expectEqual(2, @TypeOf(area).dims.get(.L));
|
try std.testing.expectEqual(2, @TypeOf(area.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy with scale" {
|
test "Scalar MulBy with scale" {
|
||||||
@ -931,8 +932,8 @@ test "Scalar MulBy with scale" {
|
|||||||
const dist = KiloMeter.splat(2.0);
|
const dist = KiloMeter.splat(2.0);
|
||||||
const mass = KiloGram.splat(3.0);
|
const mass = KiloGram.splat(3.0);
|
||||||
const prod = dist.mul(mass);
|
const prod = dist.mul(mass);
|
||||||
try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(prod.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(1, @TypeOf(prod).dims.get(.M));
|
try std.testing.expectEqual(1, @TypeOf(prod.*).dims.get(.M));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy with type change" {
|
test "Scalar MulBy with type change" {
|
||||||
@ -943,9 +944,10 @@ test "Scalar MulBy with type change" {
|
|||||||
|
|
||||||
const d = Meter.splat(3);
|
const d = Meter.splat(3);
|
||||||
const t = Second.splat(4);
|
const t = Second.splat(4);
|
||||||
|
const ms = d.mul(t);
|
||||||
|
|
||||||
try std.testing.expectEqual(12, d.mul(t).to(KmSec).data[0]);
|
try std.testing.expectEqual(12, ms.to(KmSec).data[0]);
|
||||||
try std.testing.expectApproxEqAbs(12.0, d.mul(t).to(KmSec_f).data[0], 0.0001);
|
try std.testing.expectApproxEqAbs(12.0, ms.to(KmSec_f).data[0], 0.0001);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy small" {
|
test "Scalar MulBy small" {
|
||||||
@ -971,7 +973,7 @@ test "Scalar Sqrt" {
|
|||||||
var d = MeterSquare.splat(9);
|
var d = MeterSquare.splat(9);
|
||||||
var scaled = d.sqrt();
|
var scaled = d.sqrt();
|
||||||
try std.testing.expectEqual(3, scaled.data[0]);
|
try std.testing.expectEqual(3, scaled.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(scaled.*).dims.get(.L));
|
||||||
|
|
||||||
d = MeterSquare.splat(-5);
|
d = MeterSquare.splat(-5);
|
||||||
scaled = d.sqrt();
|
scaled = d.sqrt();
|
||||||
@ -1014,7 +1016,7 @@ test "Scalar Finer scales skip dim 0" {
|
|||||||
const km = KiloMetre.splat(4);
|
const km = KiloMetre.splat(4);
|
||||||
const vel = r.mul(km);
|
const vel = r.mul(km);
|
||||||
try std.testing.expectEqual(120, vel.data[0]);
|
try std.testing.expectEqual(120, vel.data[0]);
|
||||||
try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel).scales.get(.L));
|
try std.testing.expectEqual(Scales.UnitScale.k, @TypeOf(vel.*).scales.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Conversion chain: km -> m -> cm" {
|
test "Scalar Conversion chain: km -> m -> cm" {
|
||||||
@ -1049,10 +1051,10 @@ test "Scalar Format" {
|
|||||||
const accel = MeterPerSecondSq.splat(9.81);
|
const accel = MeterPerSecondSq.splat(9.81);
|
||||||
|
|
||||||
var buf: [64]u8 = undefined;
|
var buf: [64]u8 = undefined;
|
||||||
var res = try std.fmt.bufPrint(&buf, "{d:.2}", .{m});
|
var res = try std.fmt.bufPrint(&buf, "{d:.2}", .{m.*});
|
||||||
try std.testing.expectEqualStrings("1.23m", res);
|
try std.testing.expectEqualStrings("1.23m", res);
|
||||||
|
|
||||||
res = try std.fmt.bufPrint(&buf, "{d}", .{accel});
|
res = try std.fmt.bufPrint(&buf, "{d}", .{accel.*});
|
||||||
try std.testing.expectEqualStrings("9.81m.ns⁻²", res);
|
try std.testing.expectEqualStrings("9.81m.ns⁻²", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1122,13 +1124,13 @@ test "Vector format" {
|
|||||||
const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3});
|
const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3});
|
||||||
|
|
||||||
const accel = MeterPerSecondSq.splat(9.81);
|
const accel = MeterPerSecondSq.splat(9.81);
|
||||||
const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } };
|
const momentum = &KgMeterPerSecond{ .data = .{ 43, 0, 11 } };
|
||||||
|
|
||||||
var buf: [64]u8 = undefined;
|
var buf: [64]u8 = undefined;
|
||||||
var res = try std.fmt.bufPrint(&buf, "{d}", .{accel});
|
var res = try std.fmt.bufPrint(&buf, "{d}", .{accel.*});
|
||||||
try std.testing.expectEqualStrings("(9.81, 9.81, 9.81)m.ns⁻²", res);
|
try std.testing.expectEqualStrings("(9.81, 9.81, 9.81)m.ns⁻²", res);
|
||||||
|
|
||||||
res = try std.fmt.bufPrint(&buf, "{d:.2}", .{momentum});
|
res = try std.fmt.bufPrint(&buf, "{d:.2}", .{momentum.*});
|
||||||
try std.testing.expectEqualStrings("(43.00, 0.00, 11.00)m.kg.s⁻¹", res);
|
try std.testing.expectEqualStrings("(43.00, 0.00, 11.00)m.kg.s⁻¹", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1145,8 +1147,8 @@ test "Vector Vec3 Init and Basic Arithmetic" {
|
|||||||
const v_def = Meter3.splat(5);
|
const v_def = Meter3.splat(5);
|
||||||
try std.testing.expectEqual(5, v_def.data[2]);
|
try std.testing.expectEqual(5, v_def.data[2]);
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ 10, 20, 30 } };
|
const v1 = &Meter3{ .data = .{ 10, 20, 30 } };
|
||||||
const v2 = Meter3{ .data = .{ 2, 4, 6 } };
|
const v2 = &Meter3{ .data = .{ 2, 4, 6 } };
|
||||||
|
|
||||||
const added = v1.add(v2);
|
const added = v1.add(v2);
|
||||||
try std.testing.expectEqual(12, added.data[0]);
|
try std.testing.expectEqual(12, added.data[0]);
|
||||||
@ -1175,24 +1177,24 @@ test "Vector Kinematics (scalar mul/div broadcast)" {
|
|||||||
try std.testing.expectEqual(10, vel.data[0]);
|
try std.testing.expectEqual(10, vel.data[0]);
|
||||||
try std.testing.expectEqual(20, vel.data[1]);
|
try std.testing.expectEqual(20, vel.data[1]);
|
||||||
try std.testing.expectEqual(30, vel.data[2]);
|
try std.testing.expectEqual(30, vel.data[2]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(vel).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(vel.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(-1, @TypeOf(vel).dims.get(.T));
|
try std.testing.expectEqual(-1, @TypeOf(vel.*).dims.get(.T));
|
||||||
|
|
||||||
const new_pos = vel.mul(time);
|
const new_pos = vel.mul(time);
|
||||||
try std.testing.expectEqual(100, new_pos.data[0]);
|
try std.testing.expectEqual(100, new_pos.data[0]);
|
||||||
try std.testing.expectEqual(0, @TypeOf(new_pos).dims.get(.T));
|
try std.testing.expectEqual(0, @TypeOf(new_pos.*).dims.get(.T));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Element-wise Math and Scaling" {
|
test "Vector Element-wise Math and Scaling" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ 10, 20, 30 } };
|
const v1 = &Meter3{ .data = .{ 10, 20, 30 } };
|
||||||
const v2 = Meter3{ .data = .{ 2, 5, 10 } };
|
const v2 = &Meter3{ .data = .{ 2, 5, 10 } };
|
||||||
const dv = v1.div(v2);
|
const dv = v1.div(v2);
|
||||||
try std.testing.expectEqual(5, dv.data[0]);
|
try std.testing.expectEqual(5, dv.data[0]);
|
||||||
try std.testing.expectEqual(4, dv.data[1]);
|
try std.testing.expectEqual(4, dv.data[1]);
|
||||||
try std.testing.expectEqual(3, dv.data[2]);
|
try std.testing.expectEqual(3, dv.data[2]);
|
||||||
try std.testing.expectEqual(0, @TypeOf(dv).dims.get(.L));
|
try std.testing.expectEqual(0, @TypeOf(dv.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Conversions" {
|
test "Vector Conversions" {
|
||||||
@ -1204,14 +1206,14 @@ test "Vector Conversions" {
|
|||||||
try std.testing.expectEqual(1000, v_m.data[0]);
|
try std.testing.expectEqual(1000, v_m.data[0]);
|
||||||
try std.testing.expectEqual(2000, v_m.data[1]);
|
try std.testing.expectEqual(2000, v_m.data[1]);
|
||||||
try std.testing.expectEqual(3000, v_m.data[2]);
|
try std.testing.expectEqual(3000, v_m.data[2]);
|
||||||
try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m).scales.get(.L));
|
try std.testing.expectEqual(UnitScale.none, @TypeOf(v_m.*).scales.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Length" {
|
test "Vector Length" {
|
||||||
const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v_int = MeterInt3{ .data = .{ 3, 4, 0 } };
|
const v_int = &MeterInt3{ .data = .{ 3, 4, 0 } };
|
||||||
try std.testing.expectEqual(25, v_int.lengthSqr());
|
try std.testing.expectEqual(25, v_int.lengthSqr());
|
||||||
try std.testing.expectEqual(5, v_int.length());
|
try std.testing.expectEqual(5, v_int.length());
|
||||||
|
|
||||||
@ -1224,9 +1226,9 @@ test "Vector Comparisons" {
|
|||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } };
|
const v1 = &Meter3{ .data = .{ 1000.0, 500.0, 0.0 } };
|
||||||
const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } };
|
const v2 = &KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } };
|
||||||
const v3 = KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } };
|
const v3 = &KiloMeter3{ .data = .{ 1.0, 0.6, 0.0 } };
|
||||||
|
|
||||||
try std.testing.expect(v1.eqAll(v2));
|
try std.testing.expect(v1.eqAll(v2));
|
||||||
try std.testing.expect(v1.neAll(v3));
|
try std.testing.expect(v1.neAll(v3));
|
||||||
@ -1249,7 +1251,7 @@ test "Vector vs Scalar broadcast comparison" {
|
|||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } };
|
const positions = &Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } };
|
||||||
const threshold = KiloMeter1.splat(1); // 1 km = 1000 m
|
const threshold = KiloMeter1.splat(1); // 1 km = 1000 m
|
||||||
|
|
||||||
const exceeded = positions.gt(threshold);
|
const exceeded = positions.gt(threshold);
|
||||||
@ -1267,22 +1269,22 @@ test "Vector contract — dot product (rank-1 * rank-1)" {
|
|||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3});
|
const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3});
|
||||||
|
|
||||||
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
|
const pos = &Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
|
||||||
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
|
const force = &Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
|
||||||
|
|
||||||
const work = force.contract(pos, 0, 0);
|
const work = force.contract(pos, 0, 0);
|
||||||
try std.testing.expectEqual(50.0, work.data[0]);
|
try std.testing.expectEqual(50.0, work.data[0]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M));
|
try std.testing.expectEqual(1, @TypeOf(work.*).dims.get(.M));
|
||||||
try std.testing.expectEqual(2, @TypeOf(work).dims.get(.L));
|
try std.testing.expectEqual(2, @TypeOf(work.*).dims.get(.L));
|
||||||
try std.testing.expectEqual(-2, @TypeOf(work).dims.get(.T));
|
try std.testing.expectEqual(-2, @TypeOf(work.*).dims.get(.T));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector contract — matrix multiply (rank-2 * rank-2)" {
|
test "Vector contract — matrix multiply (rank-2 * rank-2)" {
|
||||||
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
|
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
|
||||||
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
|
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
|
||||||
|
|
||||||
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
|
const a = &A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
|
||||||
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
|
const b = &B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
|
||||||
|
|
||||||
const c = a.contract(b, 1, 0);
|
const c = a.contract(b, 1, 0);
|
||||||
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
||||||
@ -1294,59 +1296,59 @@ test "Vector contract — matrix multiply (rank-2 * rank-2)" {
|
|||||||
test "Vector Abs, Pow, Sqrt and Product" {
|
test "Vector Abs, Pow, Sqrt and Product" {
|
||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } };
|
const v1 = &Meter3{ .data = .{ -2.0, 3.0, -4.0 } };
|
||||||
const v_abs = v1.abs();
|
const v_abs = v1.abs();
|
||||||
try std.testing.expectEqual(2.0, v_abs.data[0]);
|
try std.testing.expectEqual(2.0, v_abs.data[0]);
|
||||||
try std.testing.expectEqual(4.0, v_abs.data[2]);
|
try std.testing.expectEqual(4.0, v_abs.data[2]);
|
||||||
|
|
||||||
const vol = v_abs.product();
|
const vol = v_abs.product();
|
||||||
try std.testing.expectEqual(24.0, vol.data[0]);
|
try std.testing.expectEqual(24.0, vol.data[0]);
|
||||||
try std.testing.expectEqual(3, @TypeOf(vol).dims.get(.L));
|
try std.testing.expectEqual(3, @TypeOf(vol.*).dims.get(.L));
|
||||||
|
|
||||||
const area_vec = v_abs.pow(2);
|
const area_vec = v_abs.pow(2);
|
||||||
try std.testing.expectEqual(4.0, area_vec.data[0]);
|
try std.testing.expectEqual(4.0, area_vec.data[0]);
|
||||||
try std.testing.expectEqual(16.0, area_vec.data[2]);
|
try std.testing.expectEqual(16.0, area_vec.data[2]);
|
||||||
try std.testing.expectEqual(2, @TypeOf(area_vec).dims.get(.L));
|
try std.testing.expectEqual(2, @TypeOf(area_vec.*).dims.get(.L));
|
||||||
|
|
||||||
const sqrted = area_vec.sqrt();
|
const sqrted = area_vec.sqrt();
|
||||||
try std.testing.expectEqual(2, sqrted.data[0]);
|
try std.testing.expectEqual(2, sqrted.data[0]);
|
||||||
try std.testing.expectEqual(4, sqrted.data[2]);
|
try std.testing.expectEqual(4, sqrted.data[2]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(sqrted.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector mul comptime_int broadcast" {
|
test "Vector mul comptime_int broadcast" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = Meter3{ .data = .{ 1, 2, 3 } };
|
const v = &Meter3{ .data = .{ 1, 2, 3 } };
|
||||||
const scaled = v.mul(10);
|
const scaled = v.mul(10);
|
||||||
try std.testing.expectEqual(10, scaled.data[0]);
|
try std.testing.expectEqual(10, scaled.data[0]);
|
||||||
try std.testing.expectEqual(20, scaled.data[1]);
|
try std.testing.expectEqual(20, scaled.data[1]);
|
||||||
try std.testing.expectEqual(30, scaled.data[2]);
|
try std.testing.expectEqual(30, scaled.data[2]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(scaled.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector mul comptime_float broadcast" {
|
test "Vector mul comptime_float broadcast" {
|
||||||
const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } };
|
const v = &MeterF3{ .data = .{ 1.0, 2.0, 4.0 } };
|
||||||
const scaled = v.mul(0.5);
|
const scaled = v.mul(0.5);
|
||||||
try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6);
|
try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6);
|
||||||
try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6);
|
try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6);
|
||||||
try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6);
|
try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6);
|
||||||
try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(scaled.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector div comptime_int broadcast" {
|
test "Vector div comptime_int broadcast" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = Meter3{ .data = .{ 10, 20, 30 } };
|
const v = &Meter3{ .data = .{ 10, 20, 30 } };
|
||||||
const halved = v.div(2);
|
const halved = v.div(2);
|
||||||
try std.testing.expectEqual(5, halved.data[0]);
|
try std.testing.expectEqual(5, halved.data[0]);
|
||||||
try std.testing.expectEqual(10, halved.data[1]);
|
try std.testing.expectEqual(10, halved.data[1]);
|
||||||
try std.testing.expectEqual(15, halved.data[2]);
|
try std.testing.expectEqual(15, halved.data[2]);
|
||||||
try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L));
|
try std.testing.expectEqual(1, @TypeOf(halved.*).dims.get(.L));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector div comptime_float broadcast" {
|
test "Vector div comptime_float broadcast" {
|
||||||
const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3});
|
const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } };
|
const v = &MeterF3{ .data = .{ 9.0, 6.0, 3.0 } };
|
||||||
const r = v.div(3.0);
|
const r = v.div(3.0);
|
||||||
try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9);
|
try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9);
|
||||||
try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9);
|
try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9);
|
||||||
@ -1355,7 +1357,7 @@ test "Vector div comptime_float broadcast" {
|
|||||||
|
|
||||||
test "Vector eq broadcast on dimensionless" {
|
test "Vector eq broadcast on dimensionless" {
|
||||||
const DimLess3 = Tensor(i32, .{}, .{}, &.{3});
|
const DimLess3 = Tensor(i32, .{}, .{}, &.{3});
|
||||||
const v = DimLess3{ .data = .{ 1, 2, 3 } };
|
const v = &DimLess3{ .data = .{ 1, 2, 3 } };
|
||||||
|
|
||||||
const eq_res = v.eq(2);
|
const eq_res = v.eq(2);
|
||||||
try std.testing.expectEqual(false, eq_res[0]);
|
try std.testing.expectEqual(false, eq_res[0]);
|
||||||
@ -1370,7 +1372,7 @@ test "Vector eq broadcast on dimensionless" {
|
|||||||
|
|
||||||
test "Tensor idx helper and matrix access" {
|
test "Tensor idx helper and matrix access" {
|
||||||
const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
|
const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
|
||||||
var m: Mat3x3 = Mat3x3.zero;
|
var m = Mat3x3.zero;
|
||||||
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
||||||
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
||||||
m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
|
m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
|
||||||
|
|||||||
@ -262,19 +262,19 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
|||||||
for (0..ITERS) |_| {
|
for (0..ITERS) |_| {
|
||||||
// Scalar logic branch
|
// Scalar logic branch
|
||||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||||
&qa.add(&qb)
|
qa.add(qb)
|
||||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||||
&qa.sub(&qb)
|
qa.sub(qb)
|
||||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||||
&qa.mul(&qb)
|
qa.mul(qb)
|
||||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||||
&qa.div(&qb)
|
qa.div(qb)
|
||||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||||
&qa.abs()
|
qa.abs()
|
||||||
else if (comptime std.mem.eql(u8, op_name, "eq"))
|
else if (comptime std.mem.eql(u8, op_name, "eq"))
|
||||||
&qa.eq(&qb)
|
qa.eq(qb)
|
||||||
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
||||||
&qa.gt(&qb)
|
qa.gt(qb)
|
||||||
else
|
else
|
||||||
unreachable;
|
unreachable;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user