Renamed Tensor to TensorStatic to later introduce TensorAlloc and TensorGPU
This commit is contained in:
parent
9635cfb481
commit
4d275dca2d
@ -3,7 +3,7 @@ const std = @import("std");
|
|||||||
// Adjust these imports to match your actual file names
|
// Adjust these imports to match your actual file names
|
||||||
const Dimensions = @import("Dimensions.zig");
|
const Dimensions = @import("Dimensions.zig");
|
||||||
const Scales = @import("Scales.zig");
|
const Scales = @import("Scales.zig");
|
||||||
const Tensor = @import("Tensor.zig").Tensor;
|
const Tensor = @import("Tensor.zig").TensorStatic;
|
||||||
|
|
||||||
fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type {
|
fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type {
|
||||||
return struct {
|
return struct {
|
||||||
|
|||||||
461
src/Tensor.zig
461
src/Tensor.zig
@ -3,127 +3,8 @@ const Scales = @import("Scales.zig");
|
|||||||
const UnitScale = Scales.UnitScale;
|
const UnitScale = Scales.UnitScale;
|
||||||
const Dimensions = @import("Dimensions.zig");
|
const Dimensions = @import("Dimensions.zig");
|
||||||
const Dimension = Dimensions.Dimension;
|
const Dimension = Dimensions.Dimension;
|
||||||
|
const sh = @import("shared.zig");
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
// Comptime utilities
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
pub fn shapeTotal(shape: []const comptime_int) usize {
|
|
||||||
var t: comptime_int = 1;
|
|
||||||
for (shape) |s| t *= s;
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if two shapes are strictly identical.
|
|
||||||
pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
|
|
||||||
if (a.len != b.len) return false;
|
|
||||||
for (a, 0..) |v, i|
|
|
||||||
if (v != b[i]) return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
|
||||||
/// e.g. shape {3, 4} → strides {4, 1}
|
|
||||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
|
||||||
pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
|
|
||||||
var st: [shape.len]comptime_int = undefined;
|
|
||||||
if (shape.len == 0) return st;
|
|
||||||
st[shape.len - 1] = 1;
|
|
||||||
if (shape.len > 1) {
|
|
||||||
var i: comptime_int = shape.len - 1;
|
|
||||||
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
|
|
||||||
}
|
|
||||||
return st;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return a copy of `shape` with the element at `axis` removed.
|
|
||||||
pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
|
|
||||||
var out: [shape.len - 1]comptime_int = undefined;
|
|
||||||
var j: comptime_int = 0;
|
|
||||||
for (shape, 0..) |v, i| {
|
|
||||||
if (i != axis) {
|
|
||||||
out[j] = v;
|
|
||||||
j += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Concatenate two compile-time slices.
|
|
||||||
pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
|
|
||||||
var out: [a.len + b.len]comptime_int = undefined;
|
|
||||||
for (a, 0..) |v, i| out[i] = v;
|
|
||||||
for (b, 0..) |v, i| out[a.len + i] = v;
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decode a flat row-major index into N-D coordinates.
|
|
||||||
/// Called only in comptime contexts (all arguments are comptime).
|
|
||||||
pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
|
|
||||||
var coords: [n]comptime_int = undefined;
|
|
||||||
var tmp = flat;
|
|
||||||
for (0..n) |i| {
|
|
||||||
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
|
|
||||||
tmp = if (strd[i] == 0) 0 else tmp % strd[i];
|
|
||||||
}
|
|
||||||
return coords;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Encode N-D coordinates into a flat row-major index.
|
|
||||||
/// Called only in comptime contexts.
|
|
||||||
pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
|
|
||||||
var flat: usize = 0;
|
|
||||||
for (0..n) |i| flat += coords[i] * strd[i];
|
|
||||||
return flat;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`.
|
|
||||||
/// `free` holds the remaining (non-contracted) coordinates in order.
|
|
||||||
pub fn insertAxis(
|
|
||||||
comptime n: usize,
|
|
||||||
comptime axis: usize,
|
|
||||||
comptime val: usize,
|
|
||||||
comptime free: []const usize,
|
|
||||||
) [n]usize {
|
|
||||||
var out: [n]usize = undefined;
|
|
||||||
var fi: usize = 0;
|
|
||||||
for (0..n) |i| {
|
|
||||||
if (i == axis) {
|
|
||||||
out[i] = val;
|
|
||||||
} else {
|
|
||||||
out[i] = free[fi];
|
|
||||||
fi += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline fn isInt(comptime T: type) bool {
|
|
||||||
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
|
||||||
const d1: Dimensions = T1.dims;
|
|
||||||
const d2: Dimensions = T2.dims;
|
|
||||||
const s1: Scales = T1.scales;
|
|
||||||
const s2: Scales = T2.scales;
|
|
||||||
comptime var out = Scales.initFill(.none);
|
|
||||||
for (std.enums.values(Dimension)) |dim| {
|
|
||||||
const scale1 = comptime s1.get(dim);
|
|
||||||
const scale2 = comptime s2.get(dim);
|
|
||||||
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
|
||||||
.none
|
|
||||||
else if (comptime d1.get(dim) == 0)
|
|
||||||
scale2
|
|
||||||
else if (comptime d2.get(dim) == 0)
|
|
||||||
scale1
|
|
||||||
else if (comptime scale1.getFactor() > scale2.getFactor())
|
|
||||||
scale2
|
|
||||||
else
|
|
||||||
scale1);
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
// File-scope RHS normalisation helpers
|
// File-scope RHS normalisation helpers
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
@ -134,7 +15,7 @@ inline fn isTensor(comptime Rhs: type) bool {
|
|||||||
|
|
||||||
inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||||
if (comptime isTensor(Rhs)) return Rhs;
|
if (comptime isTensor(Rhs)) return Rhs;
|
||||||
return Tensor(T, .{}, .{}, &.{1});
|
return TensorStatic(T, .{}, .{}, &.{1});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
||||||
@ -162,37 +43,13 @@ inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r))
|
|||||||
},
|
},
|
||||||
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
||||||
};
|
};
|
||||||
return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
return TensorStatic(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
/// SIMD implementation of a Tensor.
|
||||||
if (n == 0) return;
|
/// Limited to tensor of ~2000 values.
|
||||||
var val = n;
|
/// For more, see either TensorAlloc or TensorGPU
|
||||||
if (val < 0) {
|
pub fn TensorStatic(
|
||||||
try writer.writeAll("\u{207B}");
|
|
||||||
val = -val;
|
|
||||||
}
|
|
||||||
var buf: [12]u8 = undefined;
|
|
||||||
const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
|
|
||||||
for (str) |c| {
|
|
||||||
const s = switch (c) {
|
|
||||||
'0' => "\u{2070}",
|
|
||||||
'1' => "\u{00B9}",
|
|
||||||
'2' => "\u{00B2}",
|
|
||||||
'3' => "\u{00B3}",
|
|
||||||
'4' => "\u{2074}",
|
|
||||||
'5' => "\u{2075}",
|
|
||||||
'6' => "\u{2076}",
|
|
||||||
'7' => "\u{2077}",
|
|
||||||
'8' => "\u{2078}",
|
|
||||||
'9' => "\u{2079}",
|
|
||||||
else => unreachable,
|
|
||||||
};
|
|
||||||
try writer.writeAll(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn Tensor(
|
|
||||||
comptime T: type,
|
comptime T: type,
|
||||||
comptime d_opt: Dimensions.ArgOpts,
|
comptime d_opt: Dimensions.ArgOpts,
|
||||||
comptime s_opt: Scales.ArgOpts,
|
comptime s_opt: Scales.ArgOpts,
|
||||||
@ -204,12 +61,15 @@ pub fn Tensor(
|
|||||||
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
|
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@setEvalBranchQuota(100_000);
|
@setEvalBranchQuota(100_000_000);
|
||||||
|
|
||||||
const _total: usize = comptime shapeTotal(shape_);
|
const _total: usize = comptime sh.shapeTotal(shape_);
|
||||||
const _strides = comptime shapeStrides(shape_);
|
const _strides = comptime sh.shapeStrides(shape_);
|
||||||
const Vec = @Vector(_total, T);
|
const Vec = @Vector(_total, T);
|
||||||
|
|
||||||
|
if (comptime _total * @bitSizeOf(T) > 1_000_000)
|
||||||
|
@compileError("Tensor too big, consider using a TensorGPU or TensorAlloc.");
|
||||||
|
|
||||||
return struct {
|
return struct {
|
||||||
data: Vec,
|
data: Vec,
|
||||||
|
|
||||||
@ -266,98 +126,103 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn add(self: *const Self, r: anytype) Tensor(
|
pub inline fn add(self: *const Self, r: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_t = rhs(r);
|
const rhs_t = rhs(r);
|
||||||
const RhsType = @TypeOf(rhs_t);
|
const RhsType = @TypeOf(rhs_t);
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||||
return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
|
return .{ .data = if (comptime sh.isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
const rr: Vec = blk: {
|
const rr: Vec = blk: {
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = TensorStatic(
|
||||||
|
T,
|
||||||
|
RhsType.dims.argsOpt(),
|
||||||
|
sh.finerScales(Self, RhsType).argsOpt(),
|
||||||
|
RhsType.shape,
|
||||||
|
);
|
||||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
|
return .{ .data = if (comptime sh.isInt(T)) l +| rr else l + rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn sub(self: *const Self, r: anytype) Tensor(
|
pub inline fn sub(self: *const Self, r: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_t = rhs(r);
|
const rhs_t = rhs(r);
|
||||||
const RhsType = @TypeOf(rhs_t);
|
const RhsType = @TypeOf(rhs_t);
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||||
return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
|
return .{ .data = if (comptime sh.isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
const rr: Vec = blk: {
|
const rr: Vec = blk: {
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
return .{ .data = if (comptime isInt(T)) l -| rr else l - rr };
|
return .{ .data = if (comptime sh.isInt(T)) l -| rr else l - rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise multiply. Dimension exponents summed.
|
/// Element-wise multiply. Dimension exponents summed.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn mul(self: *const Self, r: anytype) Tensor(
|
pub inline fn mul(self: *const Self, r: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
|
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
const RhsType = @TypeOf(rhs_q);
|
const RhsType = @TypeOf(rhs_q);
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
return .{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
return .{ .data = if (comptime sh.isInt(T)) l *| rr else l * rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise divide. Dimension exponents subtracted.
|
/// Element-wise divide. Dimension exponents subtracted.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn div(self: *const Self, r: anytype) Tensor(
|
pub inline fn div(self: *const Self, r: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
|
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
const RhsType = @TypeOf(rhs_q);
|
const RhsType = @TypeOf(rhs_q);
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
if (comptime isInt(T)) {
|
if (comptime sh.isInt(T)) {
|
||||||
return .{ .data = @divTrunc(l, rr) };
|
return .{ .data = @divTrunc(l, rr) };
|
||||||
} else {
|
} else {
|
||||||
return .{ .data = l / rr };
|
return .{ .data = l / rr };
|
||||||
@ -370,7 +235,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Raise every element to a comptime integer exponent.
|
/// Raise every element to a comptime integer exponent.
|
||||||
pub inline fn pow(self: *const Self, comptime exp: comptime_int) Tensor(
|
pub inline fn pow(self: *const Self, comptime exp: comptime_int) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.scale(exp).argsOpt(),
|
dims.scale(exp).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -386,21 +251,21 @@ pub fn Tensor(
|
|||||||
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
|
// $O(\log n)$ Exponentiation by squaring applied to the entire vector
|
||||||
inline while (e > 0) {
|
inline while (e > 0) {
|
||||||
if (e % 2 == 1) {
|
if (e % 2 == 1) {
|
||||||
result = if (comptime isInt(T)) result *| base else result * base;
|
result = if (comptime sh.isInt(T)) result *| base else result * base;
|
||||||
}
|
}
|
||||||
e /= 2;
|
e /= 2;
|
||||||
if (e > 0) {
|
if (e > 0) {
|
||||||
base = if (comptime isInt(T)) base *| base else base * base;
|
base = if (comptime sh.isInt(T)) base *| base else base * base;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (comptime !isInt(T) and exp < 0) {
|
if (comptime !sh.isInt(T) and exp < 0) {
|
||||||
result = @as(Vec, @splat(1)) / result;
|
result = @as(Vec, @splat(1)) / result;
|
||||||
}
|
}
|
||||||
return .{ .data = result };
|
return .{ .data = result };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
pub inline fn sqrt(self: *const Self) Tensor(
|
pub inline fn sqrt(self: *const Self) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.div(2).argsOpt(),
|
dims.div(2).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -434,15 +299,15 @@ pub fn Tensor(
|
|||||||
pub inline fn to(
|
pub inline fn to(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
comptime Dest: type,
|
comptime Dest: type,
|
||||||
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
) TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
||||||
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
const ActualDest = TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
||||||
|
|
||||||
if (comptime Self == ActualDest) return self;
|
if (comptime Self == ActualDest) return self;
|
||||||
|
|
||||||
// Run validation checks FIRST before dealing with types
|
// Run validation checks FIRST before dealing with types
|
||||||
if (comptime !dims.eql(ActualDest.dims))
|
if (comptime !dims.eql(ActualDest.dims))
|
||||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||||
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
if (comptime Dest.total != 1 and !sh.shapeEql(shape_, Dest.shape))
|
||||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||||
|
|
||||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||||
@ -519,13 +384,13 @@ pub fn Tensor(
|
|||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
|
inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
|
||||||
const RhsType = @TypeOf(rhs_q);
|
const RhsType = @TypeOf(rhs_q);
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
const rr: Vec = blk: {
|
const rr: Vec = blk: {
|
||||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
@ -604,41 +469,41 @@ pub fn Tensor(
|
|||||||
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
||||||
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
||||||
|
|
||||||
const sa = shapeRemoveAxis(shape_, axis_a);
|
const sa = sh.shapeRemoveAxis(shape_, axis_a);
|
||||||
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
const sb = sh.shapeRemoveAxis(OT.shape, axis_b);
|
||||||
const rs_raw = shapeCat(&sa, &sb);
|
const rs_raw = sh.shapeCat(&sa, &sb);
|
||||||
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
break :blk Tensor(
|
break :blk TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(OT.dims).argsOpt(),
|
dims.add(OT.dims).argsOpt(),
|
||||||
finerScales(Self, OT).argsOpt(),
|
sh.finerScales(Self, OT).argsOpt(),
|
||||||
rs,
|
rs,
|
||||||
);
|
);
|
||||||
} {
|
} {
|
||||||
const OT = @TypeOf(other);
|
const OT = @TypeOf(other);
|
||||||
const k: usize = comptime shape_[axis_a]; // contraction dimension
|
const k: usize = comptime shape_[axis_a]; // contraction dimension
|
||||||
|
|
||||||
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
const sa = comptime sh.shapeRemoveAxis(shape_, axis_a);
|
||||||
const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
|
const sb = comptime sh.shapeRemoveAxis(OT.shape, axis_b);
|
||||||
const rs_raw = comptime shapeCat(&sa, &sb);
|
const rs_raw = comptime sh.shapeCat(&sa, &sb);
|
||||||
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
|
|
||||||
const ResultType = Tensor(
|
const ResultType = TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(OT.dims).argsOpt(),
|
dims.add(OT.dims).argsOpt(),
|
||||||
finerScales(Self, OT).argsOpt(),
|
sh.finerScales(Self, OT).argsOpt(),
|
||||||
rs,
|
rs,
|
||||||
);
|
);
|
||||||
|
|
||||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), shape_);
|
||||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
const OtherNorm = TensorStatic(T, OT.dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), OT.shape);
|
||||||
|
|
||||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
||||||
|
|
||||||
// FAST PATH: Dot Product
|
// FAST PATH: Dot Product
|
||||||
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
||||||
if (comptime !isInt(T)) {
|
if (comptime !sh.isInt(T)) {
|
||||||
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
||||||
} else {
|
} else {
|
||||||
// For integers, we do a vectorized saturating multiply,
|
// For integers, we do a vectorized saturating multiply,
|
||||||
@ -671,7 +536,7 @@ pub fn Tensor(
|
|||||||
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
|
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
|
||||||
|
|
||||||
// Use a_arr and b_arr here
|
// Use a_arr and b_arr here
|
||||||
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
}
|
}
|
||||||
// Write to the array
|
// Write to the array
|
||||||
res_arr[i * cols + j] = acc;
|
res_arr[i * cols + j] = acc;
|
||||||
@ -682,13 +547,13 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FALLBACK PATH
|
// FALLBACK PATH
|
||||||
const rs_raw_strides = comptime shapeStrides(&rs_raw);
|
const rs_raw_strides = comptime sh.shapeStrides(&rs_raw);
|
||||||
|
|
||||||
// Create a mutable array for the result
|
// Create a mutable array for the result
|
||||||
var result_arr: [ResultType.total]T = undefined;
|
var result_arr: [ResultType.total]T = undefined;
|
||||||
|
|
||||||
for (0..ResultType.total) |res_flat| {
|
for (0..ResultType.total) |res_flat| {
|
||||||
const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
const res_coords = sh.decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
||||||
|
|
||||||
var a_free: [sa.len]usize = undefined;
|
var a_free: [sa.len]usize = undefined;
|
||||||
for (0..sa.len) |i| a_free[i] = res_coords[i];
|
for (0..sa.len) |i| a_free[i] = res_coords[i];
|
||||||
@ -697,13 +562,13 @@ pub fn Tensor(
|
|||||||
|
|
||||||
var acc: T = 0;
|
var acc: T = 0;
|
||||||
for (0..k) |ki| {
|
for (0..k) |ki| {
|
||||||
const a_coords = insertAxis(rank, axis_a, ki, &a_free);
|
const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free);
|
||||||
const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
|
const b_coords = sh.insertAxis(OT.rank, axis_b, ki, &b_free);
|
||||||
const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
|
const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides);
|
||||||
const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
const b_flat = sh.encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
||||||
|
|
||||||
// Use a_arr and b_arr here
|
// Use a_arr and b_arr here
|
||||||
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
}
|
}
|
||||||
// Write to the array
|
// Write to the array
|
||||||
result_arr[res_flat] = acc;
|
result_arr[res_flat] = acc;
|
||||||
@ -715,10 +580,10 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
||||||
/// Result dimensions are the sum of input dimensions.
|
/// Result dimensions are the sum of input dimensions.
|
||||||
pub inline fn cross(self: *const Self, other: anytype) Tensor(
|
pub inline fn cross(self: *const Self, other: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
|
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
|
sh.finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
|
||||||
&.{3},
|
&.{3},
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(other);
|
const rhs_q = rhs(other);
|
||||||
@ -734,7 +599,7 @@ pub fn Tensor(
|
|||||||
const r = p.r;
|
const r = p.r;
|
||||||
|
|
||||||
var res: [3]T = undefined;
|
var res: [3]T = undefined;
|
||||||
if (comptime isInt(T)) {
|
if (comptime sh.isInt(T)) {
|
||||||
res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
|
res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
|
||||||
res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
|
res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
|
||||||
res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
|
res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
|
||||||
@ -763,7 +628,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Product of all elements. Result has shape {1}; dimension exponent * total.
|
/// Product of all elements. Result has shape {1}; dimension exponent * total.
|
||||||
pub inline fn product(self: *const Self) Tensor(
|
pub inline fn product(self: *const Self) TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.scale(@as(comptime_int, total)).argsOpt(),
|
dims.scale(@as(comptime_int, total)).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -822,7 +687,7 @@ pub fn Tensor(
|
|||||||
else
|
else
|
||||||
try writer.print("{s}{s}", .{ uscale.str(), bu.unit() });
|
try writer.print("{s}{s}", .{ uscale.str(), bu.unit() });
|
||||||
|
|
||||||
if (v != 1) try printSuperscript(writer, v);
|
if (v != 1) try sh.printSuperscript(writer, v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -835,8 +700,8 @@ pub fn Tensor(
|
|||||||
// ─── Scalar tests ─────────────────────────────────────────────────────────
|
// ─── Scalar tests ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
test "Scalar initiat" {
|
test "Scalar initiat" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1});
|
||||||
const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1});
|
const Second = TensorStatic(f32, .{ .T = 1 }, .{ .T = .n }, &.{1});
|
||||||
|
|
||||||
const distance = Meter.splat(10);
|
const distance = Meter.splat(10);
|
||||||
const time = Second.splat(2);
|
const time = Second.splat(2);
|
||||||
@ -846,8 +711,8 @@ test "Scalar initiat" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const m1000 = Meter.splat(1000);
|
const m1000 = Meter.splat(1000);
|
||||||
const km1 = KiloMeter.splat(1);
|
const km1 = KiloMeter.splat(1);
|
||||||
@ -869,9 +734,9 @@ test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Add" {
|
test "Scalar Add" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter_f = TensorStatic(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const distance = Meter.splat(10);
|
const distance = Meter.splat(10);
|
||||||
const distance2 = Meter.splat(20);
|
const distance2 = Meter.splat(20);
|
||||||
@ -892,8 +757,8 @@ test "Scalar Add" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Sub" {
|
test "Scalar Sub" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter_f = TensorStatic(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const a = Meter.splat(500);
|
const a = Meter.splat(500);
|
||||||
const b = Meter.splat(200);
|
const b = Meter.splat(200);
|
||||||
@ -909,8 +774,8 @@ test "Scalar Sub" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy" {
|
test "Scalar MulBy" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const d = Meter.splat(3);
|
const d = Meter.splat(3);
|
||||||
const t = Second.splat(4);
|
const t = Second.splat(4);
|
||||||
@ -926,8 +791,8 @@ test "Scalar MulBy" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy with scale" {
|
test "Scalar MulBy with scale" {
|
||||||
const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1});
|
const KiloGram = TensorStatic(f32, .{ .M = 1 }, .{ .M = .k }, &.{1});
|
||||||
|
|
||||||
const dist = KiloMeter.splat(2.0);
|
const dist = KiloMeter.splat(2.0);
|
||||||
const mass = KiloGram.splat(3.0);
|
const mass = KiloGram.splat(3.0);
|
||||||
@ -937,10 +802,10 @@ test "Scalar MulBy with scale" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy with type change" {
|
test "Scalar MulBy with type change" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(f64, .{ .T = 1 }, .{}, &.{1});
|
||||||
const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});
|
const KmSec = TensorStatic(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});
|
||||||
const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});
|
const KmSec_f = TensorStatic(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const d = Meter.splat(3);
|
const d = Meter.splat(3);
|
||||||
const t = Second.splat(4);
|
const t = Second.splat(4);
|
||||||
@ -950,24 +815,24 @@ test "Scalar MulBy with type change" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy small" {
|
test "Scalar MulBy small" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .n }, &.{1});
|
||||||
const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1});
|
||||||
const d = Meter.splat(3);
|
const d = Meter.splat(3);
|
||||||
const t = Second.splat(4);
|
const t = Second.splat(4);
|
||||||
try std.testing.expectEqual(12, d.mul(t).data[0]);
|
try std.testing.expectEqual(12, d.mul(t).data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar MulBy dimensionless" {
|
test "Scalar MulBy dimensionless" {
|
||||||
const DimLess = Tensor(i128, .{}, .{}, &.{1});
|
const DimLess = TensorStatic(i128, .{}, .{}, &.{1});
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const d = Meter.splat(7);
|
const d = Meter.splat(7);
|
||||||
const scaled = d.mul(DimLess.splat(3));
|
const scaled = d.mul(DimLess.splat(3));
|
||||||
try std.testing.expectEqual(21, scaled.data[0]);
|
try std.testing.expectEqual(21, scaled.data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Sqrt" {
|
test "Scalar Sqrt" {
|
||||||
const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1});
|
const MeterSquare = TensorStatic(i128, .{ .L = 2 }, .{}, &.{1});
|
||||||
const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1});
|
const MeterSquare_f = TensorStatic(f64, .{ .L = 2 }, .{}, &.{1});
|
||||||
|
|
||||||
var d = MeterSquare.splat(9);
|
var d = MeterSquare.splat(9);
|
||||||
var scaled = d.sqrt();
|
var scaled = d.sqrt();
|
||||||
@ -984,8 +849,8 @@ test "Scalar Sqrt" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Chained: velocity and acceleration" {
|
test "Scalar Chained: velocity and acceleration" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const dist = Meter.splat(100);
|
const dist = Meter.splat(100);
|
||||||
const t1 = Second.splat(5);
|
const t1 = Second.splat(5);
|
||||||
@ -998,8 +863,8 @@ test "Scalar Chained: velocity and acceleration" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar DivBy integer exact" {
|
test "Scalar DivBy integer exact" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const dist = Meter.splat(120);
|
const dist = Meter.splat(120);
|
||||||
const time = Second.splat(4);
|
const time = Second.splat(4);
|
||||||
@ -1008,8 +873,8 @@ test "Scalar DivBy integer exact" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Finer scales skip dim 0" {
|
test "Scalar Finer scales skip dim 0" {
|
||||||
const Dimless = Tensor(i128, .{}, .{}, &.{1});
|
const Dimless = TensorStatic(i128, .{}, .{}, &.{1});
|
||||||
const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMetre = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const r = Dimless.splat(30);
|
const r = Dimless.splat(30);
|
||||||
const km = KiloMetre.splat(4);
|
const km = KiloMetre.splat(4);
|
||||||
@ -1019,9 +884,9 @@ test "Scalar Finer scales skip dim 0" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Conversion chain: km -> m -> cm" {
|
test "Scalar Conversion chain: km -> m -> cm" {
|
||||||
const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1});
|
const CentiMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .c }, &.{1});
|
||||||
|
|
||||||
const km = KiloMeter.splat(15);
|
const km = KiloMeter.splat(15);
|
||||||
const m = km.to(Meter);
|
const m = km.to(Meter);
|
||||||
@ -1031,9 +896,9 @@ test "Scalar Conversion chain: km -> m -> cm" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Conversion: hours -> minutes -> seconds" {
|
test "Scalar Conversion: hours -> minutes -> seconds" {
|
||||||
const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1});
|
const Hour = TensorStatic(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1});
|
||||||
const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1});
|
const Minute = TensorStatic(i128, .{ .T = 1 }, .{ .T = .min }, &.{1});
|
||||||
const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1});
|
const Second = TensorStatic(i128, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const h = Hour.splat(1);
|
const h = Hour.splat(1);
|
||||||
const min = h.to(Minute);
|
const min = h.to(Minute);
|
||||||
@ -1043,8 +908,8 @@ test "Scalar Conversion: hours -> minutes -> seconds" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Format" {
|
test "Scalar Format" {
|
||||||
const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1});
|
const MeterPerSecondSq = TensorStatic(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1});
|
||||||
const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const m = Meter.splat(1.23456);
|
const m = Meter.splat(1.23456);
|
||||||
const accel = MeterPerSecondSq.splat(9.81);
|
const accel = MeterPerSecondSq.splat(9.81);
|
||||||
@ -1058,52 +923,52 @@ test "Scalar Format" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Abs" {
|
test "Scalar Abs" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1});
|
const MeterF = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]);
|
try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]);
|
||||||
try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]);
|
try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Pow" {
|
test "Scalar Pow" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const d = Meter.splat(4);
|
const d = Meter.splat(4);
|
||||||
try std.testing.expectEqual(16, d.pow(2).data[0]);
|
try std.testing.expectEqual(16, d.pow(2).data[0]);
|
||||||
try std.testing.expectEqual(64, d.pow(3).data[0]);
|
try std.testing.expectEqual(64, d.pow(3).data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar mul comptime_int" {
|
test "Scalar mul comptime_int" {
|
||||||
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
const d = Meter.splat(7);
|
const d = Meter.splat(7);
|
||||||
try std.testing.expectEqual(21, d.mul(3).data[0]);
|
try std.testing.expectEqual(21, d.mul(3).data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar add/sub bare number on dimensionless scalar" {
|
test "Scalar add/sub bare number on dimensionless scalar" {
|
||||||
const DimLess = Tensor(i128, .{}, .{}, &.{1});
|
const DimLess = TensorStatic(i128, .{}, .{}, &.{1});
|
||||||
const a = DimLess.splat(10);
|
const a = DimLess.splat(10);
|
||||||
try std.testing.expectEqual(15, a.add(5).data[0]);
|
try std.testing.expectEqual(15, a.add(5).data[0]);
|
||||||
try std.testing.expectEqual(7, a.sub(3).data[0]);
|
try std.testing.expectEqual(7, a.sub(3).data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Imperial length scales" {
|
test "Scalar Imperial length scales" {
|
||||||
const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1});
|
const Foot = TensorStatic(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1});
|
||||||
const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1});
|
const Meter = TensorStatic(f64, .{ .L = 1 }, .{}, &.{1});
|
||||||
const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1});
|
const Inch = TensorStatic(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1});
|
||||||
|
|
||||||
try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9);
|
try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9);
|
||||||
try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9);
|
try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar Imperial mass scales" {
|
test "Scalar Imperial mass scales" {
|
||||||
const Pound = Tensor(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1});
|
const Pound = TensorStatic(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1});
|
||||||
const Ounce = Tensor(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1});
|
const Ounce = TensorStatic(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1});
|
||||||
|
|
||||||
const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound);
|
const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound);
|
||||||
try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6);
|
try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Scalar comparisons with comptime_int on dimensionless scalar" {
|
test "Scalar comparisons with comptime_int on dimensionless scalar" {
|
||||||
const DimLess = Tensor(i128, .{}, .{}, &.{1});
|
const DimLess = TensorStatic(i128, .{}, .{}, &.{1});
|
||||||
const x = DimLess.splat(42);
|
const x = DimLess.splat(42);
|
||||||
try std.testing.expect(x.eq(42));
|
try std.testing.expect(x.eq(42));
|
||||||
try std.testing.expect(x.gt(10));
|
try std.testing.expect(x.gt(10));
|
||||||
@ -1112,15 +977,15 @@ test "Scalar comparisons with comptime_int on dimensionless scalar" {
|
|||||||
// ─── Vector / Tensor tests ────────────────────────────────────────────────
|
// ─── Vector / Tensor tests ────────────────────────────────────────────────
|
||||||
|
|
||||||
test "Vector initiate" {
|
test "Vector initiate" {
|
||||||
const Meter4 = Tensor(f32, .{ .L = 1 }, .{}, &.{4});
|
const Meter4 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{4});
|
||||||
const m = Meter4.splat(1);
|
const m = Meter4.splat(1);
|
||||||
try std.testing.expect(m.data[0] == 1);
|
try std.testing.expect(m.data[0] == 1);
|
||||||
try std.testing.expect(m.data[3] == 1);
|
try std.testing.expect(m.data[3] == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector format" {
|
test "Vector format" {
|
||||||
const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3});
|
const MeterPerSecondSq = TensorStatic(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3});
|
||||||
const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3});
|
const KgMeterPerSecond = TensorStatic(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3});
|
||||||
|
|
||||||
const accel = MeterPerSecondSq.splat(9.81);
|
const accel = MeterPerSecondSq.splat(9.81);
|
||||||
const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } };
|
const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } };
|
||||||
@ -1134,7 +999,7 @@ test "Vector format" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Vec3 Init and Basic Arithmetic" {
|
test "Vector Vec3 Init and Basic Arithmetic" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v_zero = Meter3.zero;
|
const v_zero = Meter3.zero;
|
||||||
try std.testing.expectEqual(0, v_zero.data[0]);
|
try std.testing.expectEqual(0, v_zero.data[0]);
|
||||||
@ -1166,8 +1031,8 @@ test "Vector Vec3 Init and Basic Arithmetic" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Kinematics (scalar mul/div broadcast)" {
|
test "Vector Kinematics (scalar mul/div broadcast)" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1});
|
const Second1 = TensorStatic(i32, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
const pos = Meter3{ .data = .{ 100, 200, 300 } };
|
const pos = Meter3{ .data = .{ 100, 200, 300 } };
|
||||||
const time = Second1.splat(10);
|
const time = Second1.splat(10);
|
||||||
@ -1185,7 +1050,7 @@ test "Vector Kinematics (scalar mul/div broadcast)" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Element-wise Math and Scaling" {
|
test "Vector Element-wise Math and Scaling" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ 10, 20, 30 } };
|
const v1 = Meter3{ .data = .{ 10, 20, 30 } };
|
||||||
const v2 = Meter3{ .data = .{ 2, 5, 10 } };
|
const v2 = Meter3{ .data = .{ 2, 5, 10 } };
|
||||||
@ -1197,8 +1062,8 @@ test "Vector Element-wise Math and Scaling" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Conversions" {
|
test "Vector Conversions" {
|
||||||
const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
const KiloMeter3 = TensorStatic(i32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } };
|
const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } };
|
||||||
const v_m = v_km.to(Meter3);
|
const v_m = v_km.to(Meter3);
|
||||||
@ -1209,8 +1074,8 @@ test "Vector Conversions" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Length" {
|
test "Vector Length" {
|
||||||
const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const MeterInt3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const MeterFloat3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v_int = MeterInt3{ .data = .{ 3, 4, 0 } };
|
const v_int = MeterInt3{ .data = .{ 3, 4, 0 } };
|
||||||
try std.testing.expectEqual(25, v_int.lengthSqr());
|
try std.testing.expectEqual(25, v_int.lengthSqr());
|
||||||
@ -1222,8 +1087,8 @@ test "Vector Length" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Comparisons" {
|
test "Vector Comparisons" {
|
||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
const KiloMeter3 = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } };
|
const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } };
|
||||||
const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } };
|
const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } };
|
||||||
@ -1247,8 +1112,8 @@ test "Vector Comparisons" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector vs Scalar broadcast comparison" {
|
test "Vector vs Scalar broadcast comparison" {
|
||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const KiloMeter1 = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|
||||||
const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } };
|
const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } };
|
||||||
const threshold = KiloMeter1.splat(1); // 1 km = 1000 m
|
const threshold = KiloMeter1.splat(1); // 1 km = 1000 m
|
||||||
@ -1258,15 +1123,15 @@ test "Vector vs Scalar broadcast comparison" {
|
|||||||
try std.testing.expectEqual(true, exceeded[1]);
|
try std.testing.expectEqual(true, exceeded[1]);
|
||||||
try std.testing.expectEqual(true, exceeded[2]);
|
try std.testing.expectEqual(true, exceeded[2]);
|
||||||
|
|
||||||
const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1});
|
const Meter1 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1});
|
||||||
const exact = positions.eq(Meter1.splat(500));
|
const exact = positions.eq(Meter1.splat(500));
|
||||||
try std.testing.expect(exact[0] == true);
|
try std.testing.expect(exact[0] == true);
|
||||||
try std.testing.expect(exact[1] == false);
|
try std.testing.expect(exact[1] == false);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector contract — dot product (rank-1 * rank-1)" {
|
test "Vector contract — dot product (rank-1 * rank-1)" {
|
||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3});
|
const Newton3 = TensorStatic(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3});
|
||||||
|
|
||||||
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
|
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
|
||||||
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
|
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
|
||||||
@ -1279,21 +1144,21 @@ test "Vector contract — dot product (rank-1 * rank-1)" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector contract — matrix multiply (rank-2 * rank-2)" {
|
test "Vector contract — matrix multiply (rank-2 * rank-2)" {
|
||||||
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
|
const A = TensorStatic(f32, .{}, .{}, &.{ 2, 3 });
|
||||||
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
|
const B = TensorStatic(f32, .{}, .{}, &.{ 3, 2 });
|
||||||
|
|
||||||
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
|
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
|
||||||
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
|
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
|
||||||
|
|
||||||
const c = a.contract(b, 1, 0);
|
const c = a.contract(b, 1, 0);
|
||||||
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
try std.testing.expectEqual(58, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
||||||
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
|
try std.testing.expectEqual(64, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
|
||||||
try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]);
|
try std.testing.expectEqual(139, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]);
|
||||||
try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]);
|
try std.testing.expectEqual(154, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Vector Abs, Pow, Sqrt and Product" {
|
test "Vector Abs, Pow, Sqrt and Product" {
|
||||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
|
|
||||||
const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } };
|
const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } };
|
||||||
const v_abs = v1.abs();
|
const v_abs = v1.abs();
|
||||||
@ -1316,7 +1181,7 @@ test "Vector Abs, Pow, Sqrt and Product" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector mul comptime_int broadcast" {
|
test "Vector mul comptime_int broadcast" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = Meter3{ .data = .{ 1, 2, 3 } };
|
const v = Meter3{ .data = .{ 1, 2, 3 } };
|
||||||
const scaled = v.mul(10);
|
const scaled = v.mul(10);
|
||||||
try std.testing.expectEqual(10, scaled.data[0]);
|
try std.testing.expectEqual(10, scaled.data[0]);
|
||||||
@ -1326,7 +1191,7 @@ test "Vector mul comptime_int broadcast" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector mul comptime_float broadcast" {
|
test "Vector mul comptime_float broadcast" {
|
||||||
const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
const MeterF3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } };
|
const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } };
|
||||||
const scaled = v.mul(0.5);
|
const scaled = v.mul(0.5);
|
||||||
try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6);
|
try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6);
|
||||||
@ -1336,7 +1201,7 @@ test "Vector mul comptime_float broadcast" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector div comptime_int broadcast" {
|
test "Vector div comptime_int broadcast" {
|
||||||
const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3});
|
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = Meter3{ .data = .{ 10, 20, 30 } };
|
const v = Meter3{ .data = .{ 10, 20, 30 } };
|
||||||
const halved = v.div(2);
|
const halved = v.div(2);
|
||||||
try std.testing.expectEqual(5, halved.data[0]);
|
try std.testing.expectEqual(5, halved.data[0]);
|
||||||
@ -1346,7 +1211,7 @@ test "Vector div comptime_int broadcast" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector div comptime_float broadcast" {
|
test "Vector div comptime_float broadcast" {
|
||||||
const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3});
|
const MeterF3 = TensorStatic(f64, .{ .L = 1 }, .{}, &.{3});
|
||||||
const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } };
|
const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } };
|
||||||
const r = v.div(3.0);
|
const r = v.div(3.0);
|
||||||
try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9);
|
try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9);
|
||||||
@ -1355,7 +1220,7 @@ test "Vector div comptime_float broadcast" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Vector eq broadcast on dimensionless" {
|
test "Vector eq broadcast on dimensionless" {
|
||||||
const DimLess3 = Tensor(i32, .{}, .{}, &.{3});
|
const DimLess3 = TensorStatic(i32, .{}, .{}, &.{3});
|
||||||
const v = DimLess3{ .data = .{ 1, 2, 3 } };
|
const v = DimLess3{ .data = .{ 1, 2, 3 } };
|
||||||
|
|
||||||
const eq_res = v.eq(2);
|
const eq_res = v.eq(2);
|
||||||
@ -1370,7 +1235,7 @@ test "Vector eq broadcast on dimensionless" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Tensor idx helper and matrix access" {
|
test "Tensor idx helper and matrix access" {
|
||||||
const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
|
const Mat3x3 = TensorStatic(f32, .{}, .{}, &.{ 3, 3 });
|
||||||
var m: Mat3x3 = Mat3x3.zero;
|
var m: Mat3x3 = Mat3x3.zero;
|
||||||
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
||||||
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
||||||
@ -1383,9 +1248,9 @@ test "Tensor idx helper and matrix access" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "Tensor strides_arr correctness" {
|
test "Tensor strides_arr correctness" {
|
||||||
const T1 = Tensor(f32, .{}, .{}, &.{3});
|
const T1 = TensorStatic(f32, .{}, .{}, &.{3});
|
||||||
const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 });
|
const T2 = TensorStatic(f32, .{}, .{}, &.{ 3, 4 });
|
||||||
const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 });
|
const T3 = TensorStatic(f32, .{}, .{}, &.{ 2, 3, 4 });
|
||||||
|
|
||||||
try std.testing.expectEqual(1, T1.strides_arr[0]);
|
try std.testing.expectEqual(1, T1.strides_arr[0]);
|
||||||
try std.testing.expectEqual(4, T2.strides_arr[0]);
|
try std.testing.expectEqual(4, T2.strides_arr[0]);
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const Io = std.Io;
|
const Io = std.Io;
|
||||||
const Tensor = @import("Tensor.zig").Tensor;
|
const Tensor = @import("Tensor.zig").TensorStatic;
|
||||||
|
|
||||||
var io: Io = undefined;
|
var io: Io = undefined;
|
||||||
pub fn main(init: std.process.Init) !void {
|
pub fn main(init: std.process.Init) !void {
|
||||||
|
|||||||
149
src/shared.zig
Normal file
149
src/shared.zig
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const Scales = @import("Scales.zig");
|
||||||
|
const UnitScale = Scales.UnitScale;
|
||||||
|
const Dimensions = @import("Dimensions.zig");
|
||||||
|
const Dimension = Dimensions.Dimension;
|
||||||
|
|
||||||
|
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||||
|
var t: comptime_int = 1;
|
||||||
|
for (shape) |s| t *= s;
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if two shapes are strictly identical.
|
||||||
|
pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
|
||||||
|
if (a.len != b.len) return false;
|
||||||
|
for (a, 0..) |v, i|
|
||||||
|
if (v != b[i]) return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||||
|
/// e.g. shape {3, 4} → strides {4, 1}
|
||||||
|
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||||
|
pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
|
||||||
|
var st: [shape.len]comptime_int = undefined;
|
||||||
|
if (shape.len == 0) return st;
|
||||||
|
st[shape.len - 1] = 1;
|
||||||
|
if (shape.len > 1) {
|
||||||
|
var i: comptime_int = shape.len - 1;
|
||||||
|
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
|
||||||
|
}
|
||||||
|
return st;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a copy of `shape` with the element at `axis` removed.
|
||||||
|
pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
|
||||||
|
var out: [shape.len - 1]comptime_int = undefined;
|
||||||
|
var j: comptime_int = 0;
|
||||||
|
for (shape, 0..) |v, i| {
|
||||||
|
if (i != axis) {
|
||||||
|
out[j] = v;
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Concatenate two compile-time slices.
|
||||||
|
pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||||
|
var out: [a.len + b.len]comptime_int = undefined;
|
||||||
|
for (a, 0..) |v, i| out[i] = v;
|
||||||
|
for (b, 0..) |v, i| out[a.len + i] = v;
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a flat row-major index into N-D coordinates.
|
||||||
|
/// Called only in comptime contexts (all arguments are comptime).
|
||||||
|
pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
|
||||||
|
var coords: [n]comptime_int = undefined;
|
||||||
|
var tmp = flat;
|
||||||
|
for (0..n) |i| {
|
||||||
|
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
|
||||||
|
tmp = if (strd[i] == 0) 0 else tmp % strd[i];
|
||||||
|
}
|
||||||
|
return coords;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode N-D coordinates into a flat row-major index.
|
||||||
|
/// Called only in comptime contexts.
|
||||||
|
pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
|
||||||
|
var flat: usize = 0;
|
||||||
|
for (0..n) |i| flat += coords[i] * strd[i];
|
||||||
|
return flat;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`.
|
||||||
|
/// `free` holds the remaining (non-contracted) coordinates in order.
|
||||||
|
pub fn insertAxis(
|
||||||
|
comptime n: usize,
|
||||||
|
comptime axis: usize,
|
||||||
|
comptime val: usize,
|
||||||
|
comptime free: []const usize,
|
||||||
|
) [n]usize {
|
||||||
|
var out: [n]usize = undefined;
|
||||||
|
var fi: usize = 0;
|
||||||
|
for (0..n) |i| {
|
||||||
|
if (i == axis) {
|
||||||
|
out[i] = val;
|
||||||
|
} else {
|
||||||
|
out[i] = free[fi];
|
||||||
|
fi += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn isInt(comptime T: type) bool {
|
||||||
|
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||||
|
const d1: Dimensions = T1.dims;
|
||||||
|
const d2: Dimensions = T2.dims;
|
||||||
|
const s1: Scales = T1.scales;
|
||||||
|
const s2: Scales = T2.scales;
|
||||||
|
comptime var out = Scales.initFill(.none);
|
||||||
|
for (std.enums.values(Dimension)) |dim| {
|
||||||
|
const scale1 = comptime s1.get(dim);
|
||||||
|
const scale2 = comptime s2.get(dim);
|
||||||
|
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
||||||
|
.none
|
||||||
|
else if (comptime d1.get(dim) == 0)
|
||||||
|
scale2
|
||||||
|
else if (comptime d2.get(dim) == 0)
|
||||||
|
scale1
|
||||||
|
else if (comptime scale1.getFactor() > scale2.getFactor())
|
||||||
|
scale2
|
||||||
|
else
|
||||||
|
scale1);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||||
|
if (n == 0) return;
|
||||||
|
var val = n;
|
||||||
|
if (val < 0) {
|
||||||
|
try writer.writeAll("\u{207B}");
|
||||||
|
val = -val;
|
||||||
|
}
|
||||||
|
var buf: [12]u8 = undefined;
|
||||||
|
const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return;
|
||||||
|
for (str) |c| {
|
||||||
|
const s = switch (c) {
|
||||||
|
'0' => "\u{2070}",
|
||||||
|
'1' => "\u{00B9}",
|
||||||
|
'2' => "\u{00B2}",
|
||||||
|
'3' => "\u{00B3}",
|
||||||
|
'4' => "\u{2074}",
|
||||||
|
'5' => "\u{2075}",
|
||||||
|
'6' => "\u{2076}",
|
||||||
|
'7' => "\u{2077}",
|
||||||
|
'8' => "\u{2078}",
|
||||||
|
'9' => "\u{2079}",
|
||||||
|
else => unreachable,
|
||||||
|
};
|
||||||
|
try writer.writeAll(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user