From 4d275dca2d205b5602c54e038ae28a17a10c16dc Mon Sep 17 00:00:00 2001 From: adrien Date: Wed, 29 Apr 2026 18:07:13 +0200 Subject: [PATCH] Renamed Tensor to TensorStatic to later introduce TensorAlloc and TensorGPU --- src/Base.zig | 2 +- src/Tensor.zig | 461 ++++++++++++++++------------------------------ src/benchmark.zig | 2 +- src/shared.zig | 149 +++++++++++++++ 4 files changed, 314 insertions(+), 300 deletions(-) create mode 100644 src/shared.zig diff --git a/src/Base.zig b/src/Base.zig index 9002204..c36cb00 100644 --- a/src/Base.zig +++ b/src/Base.zig @@ -3,7 +3,7 @@ const std = @import("std"); // Adjust these imports to match your actual file names const Dimensions = @import("Dimensions.zig"); const Scales = @import("Scales.zig"); -const Tensor = @import("Tensor.zig").Tensor; +const Tensor = @import("Tensor.zig").TensorStatic; fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type { return struct { diff --git a/src/Tensor.zig b/src/Tensor.zig index 7619c03..0335dc2 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -3,127 +3,8 @@ const Scales = @import("Scales.zig"); const UnitScale = Scales.UnitScale; const Dimensions = @import("Dimensions.zig"); const Dimension = Dimensions.Dimension; +const sh = @import("shared.zig"); -// ───────────────────────────────────────────────────────────────────────────── -// Comptime utilities -// ───────────────────────────────────────────────────────────────────────────── - -pub fn shapeTotal(shape: []const comptime_int) usize { - var t: comptime_int = 1; - for (shape) |s| t *= s; - return t; -} - -/// Check if two shapes are strictly identical. -pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool { - if (a.len != b.len) return false; - for (a, 0..) |v, i| - if (v != b[i]) return false; - return true; -} - -/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). -/// e.g. shape {3, 4} → strides {4, 1} -/// shape {2, 3, 4} → strides {12, 4, 1} -pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int { - var st: [shape.len]comptime_int = undefined; - if (shape.len == 0) return st; - st[shape.len - 1] = 1; - if (shape.len > 1) { - var i: comptime_int = shape.len - 1; - while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; - } - return st; -} - -/// Return a copy of `shape` with the element at `axis` removed. -pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int { - var out: [shape.len - 1]comptime_int = undefined; - var j: comptime_int = 0; - for (shape, 0..) |v, i| { - if (i != axis) { - out[j] = v; - j += 1; - } - } - return out; -} - -/// Concatenate two compile-time slices. -pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int { - var out: [a.len + b.len]comptime_int = undefined; - for (a, 0..) |v, i| out[i] = v; - for (b, 0..) |v, i| out[a.len + i] = v; - return out; -} - -/// Decode a flat row-major index into N-D coordinates. -/// Called only in comptime contexts (all arguments are comptime). -pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize { - var coords: [n]comptime_int = undefined; - var tmp = flat; - for (0..n) |i| { - coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; - tmp = if (strd[i] == 0) 0 else tmp % strd[i]; - } - return coords; -} - -/// Encode N-D coordinates into a flat row-major index. -/// Called only in comptime contexts. -pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize { - var flat: usize = 0; - for (0..n) |i| flat += coords[i] * strd[i]; - return flat; -} - -/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`. -/// `free` holds the remaining (non-contracted) coordinates in order. -pub fn insertAxis( - comptime n: usize, - comptime axis: usize, - comptime val: usize, - comptime free: []const usize, -) [n]usize { - var out: [n]usize = undefined; - var fi: usize = 0; - for (0..n) |i| { - if (i == axis) { - out[i] = val; - } else { - out[i] = free[fi]; - fi += 1; - } - } - return out; -} - -inline fn isInt(comptime T: type) bool { - return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; -} - -fn finerScales(comptime T1: type, comptime T2: type) Scales { - const d1: Dimensions = T1.dims; - const d2: Dimensions = T2.dims; - const s1: Scales = T1.scales; - const s2: Scales = T2.scales; - comptime var out = Scales.initFill(.none); - for (std.enums.values(Dimension)) |dim| { - const scale1 = comptime s1.get(dim); - const scale2 = comptime s2.get(dim); - out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) - .none - else if (comptime d1.get(dim) == 0) - scale2 - else if (comptime d2.get(dim) == 0) - scale1 - else if (comptime scale1.getFactor() > scale2.getFactor()) - scale2 - else - scale1); - } - return out; -} // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers // ───────────────────────────────────────────────────────────────────────────── @@ -134,7 +15,7 @@ inline fn isTensor(comptime Rhs: type) bool { inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { if (comptime isTensor(Rhs)) return Rhs; - return Tensor(T, .{}, .{}, &.{1}); + return TensorStatic(T, .{}, .{}, &.{1}); } /// Take the anyvalue coming from operation and if it is a Tensor, return it. @@ -162,37 +43,13 @@ inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) }, else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)), }; - return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} }; + return TensorStatic(T, .{}, .{}, &.{1}){ .data = .{scalar} }; } -pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { - if (n == 0) return; - var val = n; - if (val < 0) { - try writer.writeAll("\u{207B}"); - val = -val; - } - var buf: [12]u8 = undefined; - const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return; - for (str) |c| { - const s = switch (c) { - '0' => "\u{2070}", - '1' => "\u{00B9}", - '2' => "\u{00B2}", - '3' => "\u{00B3}", - '4' => "\u{2074}", - '5' => "\u{2075}", - '6' => "\u{2076}", - '7' => "\u{2077}", - '8' => "\u{2078}", - '9' => "\u{2079}", - else => unreachable, - }; - try writer.writeAll(s); - } -} - -pub fn Tensor( +/// SIMD implementation of a Tensor. +/// Limited to tensor of ~2000 values. +/// For more, see either TensorAlloc or TensorGPU +pub fn TensorStatic( comptime T: type, comptime d_opt: Dimensions.ArgOpts, comptime s_opt: Scales.ArgOpts, @@ -204,12 +61,15 @@ pub fn Tensor( if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); } } - @setEvalBranchQuota(100_000); + @setEvalBranchQuota(100_000_000); - const _total: usize = comptime shapeTotal(shape_); - const _strides = comptime shapeStrides(shape_); + const _total: usize = comptime sh.shapeTotal(shape_); + const _strides = comptime sh.shapeStrides(shape_); const Vec = @Vector(_total, T); + if (comptime _total * @bitSizeOf(T) > 1_000_000) + @compileError("Tensor too big, consider using a TensorGPU or TensorAlloc."); + return struct { data: Vec, @@ -266,98 +126,103 @@ pub fn Tensor( /// Element-wise add. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). - pub inline fn add(self: *const Self, r: anytype) Tensor( + pub inline fn add(self: *const Self, r: anytype) TensorStatic( T, dims.argsOpt(), - finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_t = rhs(r); const RhsType = @TypeOf(rhs_t); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; + return .{ .data = if (comptime sh.isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; - const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = TensorStatic( + T, + RhsType.dims.argsOpt(), + sh.finerScales(Self, RhsType).argsOpt(), + RhsType.shape, + ); const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; + return .{ .data = if (comptime sh.isInt(T)) l +| rr else l + rr }; } /// Element-wise sub. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). - pub inline fn sub(self: *const Self, r: anytype) Tensor( + pub inline fn sub(self: *const Self, r: anytype) TensorStatic( T, dims.argsOpt(), - finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_t = rhs(r); const RhsType = @TypeOf(rhs_t); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; + return .{ .data = if (comptime sh.isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; - const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; + return .{ .data = if (comptime sh.isInt(T)) l -| rr else l - rr }; } /// Element-wise multiply. Dimension exponents summed. /// Shape {1} RHS is automatically broadcast across all elements. - pub inline fn mul(self: *const Self, r: anytype) Tensor( + pub inline fn mul(self: *const Self, r: anytype) TensorStatic( T, dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), - finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); - const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); - return .{ .data = if (comptime isInt(T)) l *| rr else l * rr }; + return .{ .data = if (comptime sh.isInt(T)) l *| rr else l * rr }; } /// Element-wise divide. Dimension exponents subtracted. /// Shape {1} RHS is automatically broadcast across all elements. - pub inline fn div(self: *const Self, r: anytype) Tensor( + pub inline fn div(self: *const Self, r: anytype) TensorStatic( T, dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), - finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), + sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_q = rhs(r); const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); - const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); + const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); - if (comptime isInt(T)) { + if (comptime sh.isInt(T)) { return .{ .data = @divTrunc(l, rr) }; } else { return .{ .data = l / rr }; @@ -370,7 +235,7 @@ pub fn Tensor( } /// Raise every element to a comptime integer exponent. - pub inline fn pow(self: *const Self, comptime exp: comptime_int) Tensor( + pub inline fn pow(self: *const Self, comptime exp: comptime_int) TensorStatic( T, dims.scale(exp).argsOpt(), scales.argsOpt(), @@ -386,21 +251,21 @@ pub fn Tensor( // $O(\log n)$ Exponentiation by squaring applied to the entire vector inline while (e > 0) { if (e % 2 == 1) { - result = if (comptime isInt(T)) result *| base else result * base; + result = if (comptime sh.isInt(T)) result *| base else result * base; } e /= 2; if (e > 0) { - base = if (comptime isInt(T)) base *| base else base * base; + base = if (comptime sh.isInt(T)) base *| base else base * base; } } - if (comptime !isInt(T) and exp < 0) { + if (comptime !sh.isInt(T) and exp < 0) { result = @as(Vec, @splat(1)) / result; } return .{ .data = result }; } /// Square root of every element. All dimension exponents must be even. - pub inline fn sqrt(self: *const Self) Tensor( + pub inline fn sqrt(self: *const Self) TensorStatic( T, dims.div(2).argsOpt(), scales.argsOpt(), @@ -434,15 +299,15 @@ pub fn Tensor( pub inline fn to( self: *const Self, comptime Dest: type, - ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { - const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); + ) TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { + const ActualDest = TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); if (comptime Self == ActualDest) return self; // Run validation checks FIRST before dealing with types if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); - if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) + if (comptime Dest.total != 1 and !sh.shapeEql(shape_, Dest.shape)) @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); @@ -519,13 +384,13 @@ pub fn Tensor( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } { const RhsType = @TypeOf(rhs_q); - if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); - const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { - const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); + const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; @@ -604,41 +469,41 @@ pub fn Tensor( if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds"); if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes"); - const sa = shapeRemoveAxis(shape_, axis_a); - const sb = shapeRemoveAxis(OT.shape, axis_b); - const rs_raw = shapeCat(&sa, &sb); + const sa = sh.shapeRemoveAxis(shape_, axis_a); + const sb = sh.shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = sh.shapeCat(&sa, &sb); const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; - break :blk Tensor( + break :blk TensorStatic( T, dims.add(OT.dims).argsOpt(), - finerScales(Self, OT).argsOpt(), + sh.finerScales(Self, OT).argsOpt(), rs, ); } { const OT = @TypeOf(other); const k: usize = comptime shape_[axis_a]; // contraction dimension - const sa = comptime shapeRemoveAxis(shape_, axis_a); - const sb = comptime shapeRemoveAxis(OT.shape, axis_b); - const rs_raw = comptime shapeCat(&sa, &sb); + const sa = comptime sh.shapeRemoveAxis(shape_, axis_a); + const sb = comptime sh.shapeRemoveAxis(OT.shape, axis_b); + const rs_raw = comptime sh.shapeCat(&sa, &sb); const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; - const ResultType = Tensor( + const ResultType = TensorStatic( T, dims.add(OT.dims).argsOpt(), - finerScales(Self, OT).argsOpt(), + sh.finerScales(Self, OT).argsOpt(), rs, ); - const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); - const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); + const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), shape_); + const OtherNorm = TensorStatic(T, OT.dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), OT.shape); const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; // FAST PATH: Dot Product if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) { - if (comptime !isInt(T)) { + if (comptime !sh.isInt(T)) { return .{ .data = @splat(@reduce(.Add, a_data * b_data)) }; } else { // For integers, we do a vectorized saturating multiply, @@ -671,7 +536,7 @@ pub fn Tensor( const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1]; // Use a_arr and b_arr here - if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; + if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; } // Write to the array res_arr[i * cols + j] = acc; @@ -682,13 +547,13 @@ pub fn Tensor( } // FALLBACK PATH - const rs_raw_strides = comptime shapeStrides(&rs_raw); + const rs_raw_strides = comptime sh.shapeStrides(&rs_raw); // Create a mutable array for the result var result_arr: [ResultType.total]T = undefined; for (0..ResultType.total) |res_flat| { - const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); + const res_coords = sh.decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); var a_free: [sa.len]usize = undefined; for (0..sa.len) |i| a_free[i] = res_coords[i]; @@ -697,13 +562,13 @@ pub fn Tensor( var acc: T = 0; for (0..k) |ki| { - const a_coords = insertAxis(rank, axis_a, ki, &a_free); - const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free); - const a_flat = encodeFlatCoords(&a_coords, rank, _strides); - const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free); + const b_coords = sh.insertAxis(OT.rank, axis_b, ki, &b_free); + const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = sh.encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); // Use a_arr and b_arr here - if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; + if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; } // Write to the array result_arr[res_flat] = acc; @@ -715,10 +580,10 @@ pub fn Tensor( /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. /// Result dimensions are the sum of input dimensions. - pub inline fn cross(self: *const Self, other: anytype) Tensor( + pub inline fn cross(self: *const Self, other: anytype) TensorStatic( T, dims.add(RhsT(@TypeOf(other)).dims).argsOpt(), - finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), + sh.finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), &.{3}, ) { const rhs_q = rhs(other); @@ -734,7 +599,7 @@ pub fn Tensor( const r = p.r; var res: [3]T = undefined; - if (comptime isInt(T)) { + if (comptime sh.isInt(T)) { res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]); res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]); res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]); @@ -763,7 +628,7 @@ pub fn Tensor( } /// Product of all elements. Result has shape {1}; dimension exponent * total. - pub inline fn product(self: *const Self) Tensor( + pub inline fn product(self: *const Self) TensorStatic( T, dims.scale(@as(comptime_int, total)).argsOpt(), scales.argsOpt(), @@ -822,7 +687,7 @@ pub fn Tensor( else try writer.print("{s}{s}", .{ uscale.str(), bu.unit() }); - if (v != 1) try printSuperscript(writer, v); + if (v != 1) try sh.printSuperscript(writer, v); } } }; @@ -835,8 +700,8 @@ pub fn Tensor( // ─── Scalar tests ───────────────────────────────────────────────────────── test "Scalar initiat" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = @enumFromInt(-3) }, &.{1}); + const Second = TensorStatic(f32, .{ .T = 1 }, .{ .T = .n }, &.{1}); const distance = Meter.splat(10); const time = Second.splat(2); @@ -846,8 +711,8 @@ test "Scalar initiat" { } test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const m1000 = Meter.splat(1000); const km1 = KiloMeter.splat(1); @@ -869,9 +734,9 @@ test "Scalar comparisons (eq, ne, gt, gte, lt, lte)" { } test "Scalar Add" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloMeter_f = TensorStatic(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); const distance = Meter.splat(10); const distance2 = Meter.splat(20); @@ -892,8 +757,8 @@ test "Scalar Add" { } test "Scalar Sub" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const KiloMeter_f = TensorStatic(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); const a = Meter.splat(500); const b = Meter.splat(200); @@ -909,8 +774,8 @@ test "Scalar Sub" { } test "Scalar MulBy" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1}); const d = Meter.splat(3); const t = Second.splat(4); @@ -926,8 +791,8 @@ test "Scalar MulBy" { } test "Scalar MulBy with scale" { - const KiloMeter = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const KiloGram = Tensor(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); + const KiloMeter = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const KiloGram = TensorStatic(f32, .{ .M = 1 }, .{ .M = .k }, &.{1}); const dist = KiloMeter.splat(2.0); const mass = KiloGram.splat(3.0); @@ -937,10 +802,10 @@ test "Scalar MulBy with scale" { } test "Scalar MulBy with type change" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const Second = Tensor(f64, .{ .T = 1 }, .{}, &.{1}); - const KmSec = Tensor(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); - const KmSec_f = Tensor(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Second = TensorStatic(f64, .{ .T = 1 }, .{}, &.{1}); + const KmSec = TensorStatic(i64, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); + const KmSec_f = TensorStatic(f32, .{ .L = 1, .T = 1 }, .{ .L = .k }, &.{1}); const d = Meter.splat(3); const t = Second.splat(4); @@ -950,24 +815,24 @@ test "Scalar MulBy with type change" { } test "Scalar MulBy small" { - const Meter = Tensor(i128, .{ .L = 1 }, .{ .L = .n }, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .n }, &.{1}); + const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1}); const d = Meter.splat(3); const t = Second.splat(4); try std.testing.expectEqual(12, d.mul(t).data[0]); } test "Scalar MulBy dimensionless" { - const DimLess = Tensor(i128, .{}, .{}, &.{1}); - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const DimLess = TensorStatic(i128, .{}, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); const d = Meter.splat(7); const scaled = d.mul(DimLess.splat(3)); try std.testing.expectEqual(21, scaled.data[0]); } test "Scalar Sqrt" { - const MeterSquare = Tensor(i128, .{ .L = 2 }, .{}, &.{1}); - const MeterSquare_f = Tensor(f64, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare = TensorStatic(i128, .{ .L = 2 }, .{}, &.{1}); + const MeterSquare_f = TensorStatic(f64, .{ .L = 2 }, .{}, &.{1}); var d = MeterSquare.splat(9); var scaled = d.sqrt(); @@ -984,8 +849,8 @@ test "Scalar Sqrt" { } test "Scalar Chained: velocity and acceleration" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1}); const dist = Meter.splat(100); const t1 = Second.splat(5); @@ -998,8 +863,8 @@ test "Scalar Chained: velocity and acceleration" { } test "Scalar DivBy integer exact" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const Second = Tensor(f32, .{ .T = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const Second = TensorStatic(f32, .{ .T = 1 }, .{}, &.{1}); const dist = Meter.splat(120); const time = Second.splat(4); @@ -1008,8 +873,8 @@ test "Scalar DivBy integer exact" { } test "Scalar Finer scales skip dim 0" { - const Dimless = Tensor(i128, .{}, .{}, &.{1}); - const KiloMetre = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Dimless = TensorStatic(i128, .{}, .{}, &.{1}); + const KiloMetre = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); const r = Dimless.splat(30); const km = KiloMetre.splat(4); @@ -1019,9 +884,9 @@ test "Scalar Finer scales skip dim 0" { } test "Scalar Conversion chain: km -> m -> cm" { - const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const CentiMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); + const KiloMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const CentiMeter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .c }, &.{1}); const km = KiloMeter.splat(15); const m = km.to(Meter); @@ -1031,9 +896,9 @@ test "Scalar Conversion chain: km -> m -> cm" { } test "Scalar Conversion: hours -> minutes -> seconds" { - const Hour = Tensor(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); - const Minute = Tensor(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); - const Second = Tensor(i128, .{ .T = 1 }, .{}, &.{1}); + const Hour = TensorStatic(i128, .{ .T = 1 }, .{ .T = .hour }, &.{1}); + const Minute = TensorStatic(i128, .{ .T = 1 }, .{ .T = .min }, &.{1}); + const Second = TensorStatic(i128, .{ .T = 1 }, .{}, &.{1}); const h = Hour.splat(1); const min = h.to(Minute); @@ -1043,8 +908,8 @@ test "Scalar Conversion: hours -> minutes -> seconds" { } test "Scalar Format" { - const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1}); - const Meter = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const MeterPerSecondSq = TensorStatic(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{1}); + const Meter = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1}); const m = Meter.splat(1.23456); const accel = MeterPerSecondSq.splat(9.81); @@ -1058,52 +923,52 @@ test "Scalar Format" { } test "Scalar Abs" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); - const MeterF = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); + const MeterF = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1}); try std.testing.expectEqual(50, Meter.splat(-50).abs().data[0]); try std.testing.expectEqual(42.5, MeterF.splat(-42.5).abs().data[0]); } test "Scalar Pow" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); const d = Meter.splat(4); try std.testing.expectEqual(16, d.pow(2).data[0]); try std.testing.expectEqual(64, d.pow(3).data[0]); } test "Scalar mul comptime_int" { - const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); + const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1}); const d = Meter.splat(7); try std.testing.expectEqual(21, d.mul(3).data[0]); } test "Scalar add/sub bare number on dimensionless scalar" { - const DimLess = Tensor(i128, .{}, .{}, &.{1}); + const DimLess = TensorStatic(i128, .{}, .{}, &.{1}); const a = DimLess.splat(10); try std.testing.expectEqual(15, a.add(5).data[0]); try std.testing.expectEqual(7, a.sub(3).data[0]); } test "Scalar Imperial length scales" { - const Foot = Tensor(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); - const Meter = Tensor(f64, .{ .L = 1 }, .{}, &.{1}); - const Inch = Tensor(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); + const Foot = TensorStatic(f64, .{ .L = 1 }, .{ .L = .ft }, &.{1}); + const Meter = TensorStatic(f64, .{ .L = 1 }, .{}, &.{1}); + const Inch = TensorStatic(f64, .{ .L = 1 }, .{ .L = .inch }, &.{1}); try std.testing.expectApproxEqAbs(0.3048, Foot.splat(1.0).to(Meter).data[0], 1e-9); try std.testing.expectApproxEqAbs(1.0, Inch.splat(12.0).to(Foot).data[0], 1e-9); } test "Scalar Imperial mass scales" { - const Pound = Tensor(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1}); - const Ounce = Tensor(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1}); + const Pound = TensorStatic(f64, .{ .M = 1 }, .{ .M = .lb }, &.{1}); + const Ounce = TensorStatic(f64, .{ .M = 1 }, .{ .M = .oz }, &.{1}); const total = Pound.splat(2.0).add(Ounce.splat(8.0)).to(Pound); try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6); } test "Scalar comparisons with comptime_int on dimensionless scalar" { - const DimLess = Tensor(i128, .{}, .{}, &.{1}); + const DimLess = TensorStatic(i128, .{}, .{}, &.{1}); const x = DimLess.splat(42); try std.testing.expect(x.eq(42)); try std.testing.expect(x.gt(10)); @@ -1112,15 +977,15 @@ test "Scalar comparisons with comptime_int on dimensionless scalar" { // ─── Vector / Tensor tests ──────────────────────────────────────────────── test "Vector initiate" { - const Meter4 = Tensor(f32, .{ .L = 1 }, .{}, &.{4}); + const Meter4 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{4}); const m = Meter4.splat(1); try std.testing.expect(m.data[0] == 1); try std.testing.expect(m.data[3] == 1); } test "Vector format" { - const MeterPerSecondSq = Tensor(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); - const KgMeterPerSecond = Tensor(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); + const MeterPerSecondSq = TensorStatic(f32, .{ .L = 1, .T = -2 }, .{ .T = .n }, &.{3}); + const KgMeterPerSecond = TensorStatic(f32, .{ .M = 1, .L = 1, .T = -1 }, .{ .M = .k }, &.{3}); const accel = MeterPerSecondSq.splat(9.81); const momentum = KgMeterPerSecond{ .data = .{ 43, 0, 11 } }; @@ -1134,7 +999,7 @@ test "Vector format" { } test "Vector Vec3 Init and Basic Arithmetic" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); const v_zero = Meter3.zero; try std.testing.expectEqual(0, v_zero.data[0]); @@ -1166,8 +1031,8 @@ test "Vector Vec3 Init and Basic Arithmetic" { } test "Vector Kinematics (scalar mul/div broadcast)" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const Second1 = Tensor(i32, .{ .T = 1 }, .{}, &.{1}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); + const Second1 = TensorStatic(i32, .{ .T = 1 }, .{}, &.{1}); const pos = Meter3{ .data = .{ 100, 200, 300 } }; const time = Second1.splat(10); @@ -1185,7 +1050,7 @@ test "Vector Kinematics (scalar mul/div broadcast)" { } test "Vector Element-wise Math and Scaling" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); const v1 = Meter3{ .data = .{ 10, 20, 30 } }; const v2 = Meter3{ .data = .{ 2, 5, 10 } }; @@ -1197,8 +1062,8 @@ test "Vector Element-wise Math and Scaling" { } test "Vector Conversions" { - const KiloMeter3 = Tensor(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = TensorStatic(i32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); const v_km = KiloMeter3{ .data = .{ 1, 2, 3 } }; const v_m = v_km.to(Meter3); @@ -1209,8 +1074,8 @@ test "Vector Conversions" { } test "Vector Length" { - const MeterInt3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); - const MeterFloat3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const MeterInt3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); + const MeterFloat3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); const v_int = MeterInt3{ .data = .{ 3, 4, 0 } }; try std.testing.expectEqual(25, v_int.lengthSqr()); @@ -1222,8 +1087,8 @@ test "Vector Length" { } test "Vector Comparisons" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const KiloMeter3 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); + const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter3 = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{3}); const v1 = Meter3{ .data = .{ 1000.0, 500.0, 0.0 } }; const v2 = KiloMeter3{ .data = .{ 1.0, 0.5, 0.0 } }; @@ -1247,8 +1112,8 @@ test "Vector Comparisons" { } test "Vector vs Scalar broadcast comparison" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); + const KiloMeter1 = TensorStatic(f32, .{ .L = 1 }, .{ .L = .k }, &.{1}); const positions = Meter3{ .data = .{ 500.0, 1200.0, 3000.0 } }; const threshold = KiloMeter1.splat(1); // 1 km = 1000 m @@ -1258,15 +1123,15 @@ test "Vector vs Scalar broadcast comparison" { try std.testing.expectEqual(true, exceeded[1]); try std.testing.expectEqual(true, exceeded[2]); - const Meter1 = Tensor(f32, .{ .L = 1 }, .{}, &.{1}); + const Meter1 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{1}); const exact = positions.eq(Meter1.splat(500)); try std.testing.expect(exact[0] == true); try std.testing.expect(exact[1] == false); } test "Vector contract — dot product (rank-1 * rank-1)" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); - const Newton3 = Tensor(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); + const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); + const Newton3 = TensorStatic(f32, .{ .M = 1, .L = 1, .T = -2 }, .{}, &.{3}); const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } }; const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } }; @@ -1279,21 +1144,21 @@ test "Vector contract — dot product (rank-1 * rank-1)" { } test "Vector contract — matrix multiply (rank-2 * rank-2)" { - const A = Tensor(f32, .{}, .{}, &.{ 2, 3 }); - const B = Tensor(f32, .{}, .{}, &.{ 3, 2 }); + const A = TensorStatic(f32, .{}, .{}, &.{ 2, 3 }); + const B = TensorStatic(f32, .{}, .{}, &.{ 3, 2 }); const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } }; const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } }; const c = a.contract(b, 1, 0); - try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); - try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); - try std.testing.expectEqual(139, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]); - try std.testing.expectEqual(154, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]); + try std.testing.expectEqual(58, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]); + try std.testing.expectEqual(64, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]); + try std.testing.expectEqual(139, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 0 })]); + try std.testing.expectEqual(154, c.data[TensorStatic(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 1, 1 })]); } test "Vector Abs, Pow, Sqrt and Product" { - const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); const v1 = Meter3{ .data = .{ -2.0, 3.0, -4.0 } }; const v_abs = v1.abs(); @@ -1316,7 +1181,7 @@ test "Vector Abs, Pow, Sqrt and Product" { } test "Vector mul comptime_int broadcast" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); const v = Meter3{ .data = .{ 1, 2, 3 } }; const scaled = v.mul(10); try std.testing.expectEqual(10, scaled.data[0]); @@ -1326,7 +1191,7 @@ test "Vector mul comptime_int broadcast" { } test "Vector mul comptime_float broadcast" { - const MeterF3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3}); + const MeterF3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3}); const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } }; const scaled = v.mul(0.5); try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6); @@ -1336,7 +1201,7 @@ test "Vector mul comptime_float broadcast" { } test "Vector div comptime_int broadcast" { - const Meter3 = Tensor(i32, .{ .L = 1 }, .{}, &.{3}); + const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3}); const v = Meter3{ .data = .{ 10, 20, 30 } }; const halved = v.div(2); try std.testing.expectEqual(5, halved.data[0]); @@ -1346,7 +1211,7 @@ test "Vector div comptime_int broadcast" { } test "Vector div comptime_float broadcast" { - const MeterF3 = Tensor(f64, .{ .L = 1 }, .{}, &.{3}); + const MeterF3 = TensorStatic(f64, .{ .L = 1 }, .{}, &.{3}); const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } }; const r = v.div(3.0); try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9); @@ -1355,7 +1220,7 @@ test "Vector div comptime_float broadcast" { } test "Vector eq broadcast on dimensionless" { - const DimLess3 = Tensor(i32, .{}, .{}, &.{3}); + const DimLess3 = TensorStatic(i32, .{}, .{}, &.{3}); const v = DimLess3{ .data = .{ 1, 2, 3 } }; const eq_res = v.eq(2); @@ -1370,7 +1235,7 @@ test "Vector eq broadcast on dimensionless" { } test "Tensor idx helper and matrix access" { - const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 }); + const Mat3x3 = TensorStatic(f32, .{}, .{}, &.{ 3, 3 }); var m: Mat3x3 = Mat3x3.zero; m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0; m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0; @@ -1383,9 +1248,9 @@ test "Tensor idx helper and matrix access" { } test "Tensor strides_arr correctness" { - const T1 = Tensor(f32, .{}, .{}, &.{3}); - const T2 = Tensor(f32, .{}, .{}, &.{ 3, 4 }); - const T3 = Tensor(f32, .{}, .{}, &.{ 2, 3, 4 }); + const T1 = TensorStatic(f32, .{}, .{}, &.{3}); + const T2 = TensorStatic(f32, .{}, .{}, &.{ 3, 4 }); + const T3 = TensorStatic(f32, .{}, .{}, &.{ 2, 3, 4 }); try std.testing.expectEqual(1, T1.strides_arr[0]); try std.testing.expectEqual(4, T2.strides_arr[0]); diff --git a/src/benchmark.zig b/src/benchmark.zig index df656a9..8b567c5 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -1,6 +1,6 @@ const std = @import("std"); const Io = std.Io; -const Tensor = @import("Tensor.zig").Tensor; +const Tensor = @import("Tensor.zig").TensorStatic; var io: Io = undefined; pub fn main(init: std.process.Init) !void { diff --git a/src/shared.zig b/src/shared.zig new file mode 100644 index 0000000..137011f --- /dev/null +++ b/src/shared.zig @@ -0,0 +1,149 @@ +const std = @import("std"); +const Scales = @import("Scales.zig"); +const UnitScale = Scales.UnitScale; +const Dimensions = @import("Dimensions.zig"); +const Dimension = Dimensions.Dimension; + +pub fn shapeTotal(shape: []const comptime_int) usize { + var t: comptime_int = 1; + for (shape) |s| t *= s; + return t; +} + +/// Check if two shapes are strictly identical. +pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool { + if (a.len != b.len) return false; + for (a, 0..) |v, i| + if (v != b[i]) return false; + return true; +} + +/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). +/// e.g. shape {3, 4} → strides {4, 1} +/// shape {2, 3, 4} → strides {12, 4, 1} +pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int { + var st: [shape.len]comptime_int = undefined; + if (shape.len == 0) return st; + st[shape.len - 1] = 1; + if (shape.len > 1) { + var i: comptime_int = shape.len - 1; + while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; + } + return st; +} + +/// Return a copy of `shape` with the element at `axis` removed. +pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int { + var out: [shape.len - 1]comptime_int = undefined; + var j: comptime_int = 0; + for (shape, 0..) |v, i| { + if (i != axis) { + out[j] = v; + j += 1; + } + } + return out; +} + +/// Concatenate two compile-time slices. +pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int { + var out: [a.len + b.len]comptime_int = undefined; + for (a, 0..) |v, i| out[i] = v; + for (b, 0..) |v, i| out[a.len + i] = v; + return out; +} + +/// Decode a flat row-major index into N-D coordinates. +/// Called only in comptime contexts (all arguments are comptime). +pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize { + var coords: [n]comptime_int = undefined; + var tmp = flat; + for (0..n) |i| { + coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; + tmp = if (strd[i] == 0) 0 else tmp % strd[i]; + } + return coords; +} + +/// Encode N-D coordinates into a flat row-major index. +/// Called only in comptime contexts. +pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize { + var flat: usize = 0; + for (0..n) |i| flat += coords[i] * strd[i]; + return flat; +} + +/// Rebuild a full coordinate array by inserting `val` at `axis` into `free`. +/// `free` holds the remaining (non-contracted) coordinates in order. +pub fn insertAxis( + comptime n: usize, + comptime axis: usize, + comptime val: usize, + comptime free: []const usize, +) [n]usize { + var out: [n]usize = undefined; + var fi: usize = 0; + for (0..n) |i| { + if (i == axis) { + out[i] = val; + } else { + out[i] = free[fi]; + fi += 1; + } + } + return out; +} + +pub inline fn isInt(comptime T: type) bool { + return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; +} + +pub fn finerScales(comptime T1: type, comptime T2: type) Scales { + const d1: Dimensions = T1.dims; + const d2: Dimensions = T2.dims; + const s1: Scales = T1.scales; + const s2: Scales = T2.scales; + comptime var out = Scales.initFill(.none); + for (std.enums.values(Dimension)) |dim| { + const scale1 = comptime s1.get(dim); + const scale2 = comptime s2.get(dim); + out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) + .none + else if (comptime d1.get(dim) == 0) + scale2 + else if (comptime d2.get(dim) == 0) + scale1 + else if (comptime scale1.getFactor() > scale2.getFactor()) + scale2 + else + scale1); + } + return out; +} + +pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void { + if (n == 0) return; + var val = n; + if (val < 0) { + try writer.writeAll("\u{207B}"); + val = -val; + } + var buf: [12]u8 = undefined; + const str = std.fmt.bufPrint(&buf, "{d}", .{val}) catch return; + for (str) |c| { + const s = switch (c) { + '0' => "\u{2070}", + '1' => "\u{00B9}", + '2' => "\u{00B2}", + '3' => "\u{00B3}", + '4' => "\u{2074}", + '5' => "\u{2075}", + '6' => "\u{2076}", + '7' => "\u{2077}", + '8' => "\u{2078}", + '9' => "\u{2079}", + else => unreachable, + }; + try writer.writeAll(s); + } +}