Removed every `inline for (0..total)` loop, replacing each one with either a vectorized builtin or a plain (non-inline) for loop

This prevents huge binaries when a Tensor type contains a large number of Scalar elements
This commit is contained in:
AdrienBouvais 2026-04-27 16:11:46 +02:00
parent cd954b379b
commit 44aaa8a8b2
4 changed files with 190 additions and 123 deletions

View File

@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int),
/// Unspecified dimensions default to 0. /// Unspecified dimensions default to 0.
pub fn init(comptime init_val: ArgOpts) Self { pub fn init(comptime init_val: ArgOpts) Self {
var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
inline for (std.meta.fields(@TypeOf(init_val))) |f| for (std.meta.fields(@TypeOf(init_val))) |f|
s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
return s; return s;
} }
pub fn initFill(comptime val: comptime_int) Self { pub fn initFill(comptime val: comptime_int) Self {
return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
} }
pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
return self.data.get(key); comptime return self.data.get(key);
} }
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void
pub fn argsOpt(self: Self) ArgOpts { pub fn argsOpt(self: Self) ArgOpts {
var args: ArgOpts = undefined; var args: ArgOpts = undefined;
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
@field(args, @tagName(d)) = self.get(d); @field(args, @tagName(d)) = self.get(d);
return args; comptime return args;
} }
/// Add exponents component-wise. Used internally by `mul`. /// Add exponents component-wise. Used internally by `mul`.
pub fn add(comptime a: Self, comptime b: Self) Self { pub fn add(comptime a: Self, comptime b: Self) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) + b.get(d)); result.set(d, a.get(d) + b.get(d));
return result; comptime return result;
} }
/// Subtract exponents component-wise. Used internally by `div`. /// Subtract exponents component-wise. Used internally by `div`.
pub fn sub(comptime a: Self, comptime b: Self) Self { pub fn sub(comptime a: Self, comptime b: Self) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) - b.get(d)); result.set(d, a.get(d) - b.get(d));
return result; comptime return result;
} }
/// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) * exp); result.set(d, a.get(d) * exp);
return result; comptime return result;
} }
pub fn div(comptime a: Self, comptime exp: comptime_int) Self { pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
var result = Self.initFill(0); var result = Self.initFill(0);
inline for (std.enums.values(Dimension)) |d| inline for (std.enums.values(Dimension)) |d|
result.set(d, a.get(d) / exp); result.set(d, a.get(d) / exp);
return result; comptime return result;
} }
/// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.

View File

@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) {
} }
pub inline fn getFactor(self: @This()) comptime_float { pub inline fn getFactor(self: @This()) comptime_float {
return comptime switch (self) { comptime return switch (self) {
// Standard SI Exponents // Standard SI Exponents
inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) {
inline .lb => 453.59237, inline .lb => 453.59237,
inline .st => 6350.29318, inline .st => 6350.29318,
inline else => @floatFromInt(@intFromEnum(self)), else => @floatFromInt(@intFromEnum(self)),
}; };
} }
}; };
@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self {
else else
s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
} }
return s; return comptime s;
} }
pub fn initFill(comptime val: UnitScale) Self { pub fn initFill(comptime val: UnitScale) Self {
return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
} }
pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
} }
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
comptime self.data.set(key, val); self.data.set(key, val);
} }
pub fn argsOpt(self: Self) ArgOpts { pub fn argsOpt(self: Self) ArgOpts {
@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float
factor /= base; factor /= base;
} }
} }
return comptime factor; comptime return factor;
} }

View File

@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
// Comptime utilities // Comptime utilities
// //
pub fn shapeTotal(comptime shape: []const usize) usize { pub fn shapeTotal(comptime shape: []const comptime_int) usize {
var t: usize = 1; var t: comptime_int = 1;
for (shape) |s| t *= s; for (shape) |s| t *= s;
return t; return t;
} }
/// Check if two shapes are strictly identical. /// Check if two shapes are strictly identical.
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
if (a.len != b.len) return false; if (a.len != b.len) return false;
for (a, 0..) |v, i| for (a, 0..) |v, i|
if (v != b[i]) return false; if (v != b[i]) return false;
@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
/// e.g. shape {3, 4} strides {4, 1} /// e.g. shape {3, 4} strides {4, 1}
/// shape {2, 3, 4} strides {12, 4, 1} /// shape {2, 3, 4} strides {12, 4, 1}
pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
var st: [shape.len]usize = undefined; var st: [shape.len]comptime_int = undefined;
if (shape.len == 0) return st; if (shape.len == 0) return st;
st[shape.len - 1] = 1; st[shape.len - 1] = 1;
if (shape.len > 1) { if (shape.len > 1) {
var i: usize = shape.len - 1; var i: comptime_int = shape.len - 1;
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
} }
return st; return st;
} }
/// Return a copy of `shape` with the element at `axis` removed. /// Return a copy of `shape` with the element at `axis` removed.
pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
var out: [shape.len - 1]usize = undefined; var out: [shape.len - 1]comptime_int = undefined;
var j: usize = 0; var j: comptime_int = 0;
for (shape, 0..) |v, i| { for (shape, 0..) |v, i| {
if (i != axis) { if (i != axis) {
out[j] = v; out[j] = v;
@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha
} }
/// Concatenate two compile-time slices. /// Concatenate two compile-time slices.
pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
var out: [a.len + b.len]usize = undefined; var out: [a.len + b.len]comptime_int = undefined;
for (a, 0..) |v, i| out[i] = v; for (a, 0..) |v, i| out[i] = v;
for (b, 0..) |v, i| out[a.len + i] = v; for (b, 0..) |v, i| out[a.len + i] = v;
return out; return out;
@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b
/// Decode a flat row-major index into N-D coordinates. /// Decode a flat row-major index into N-D coordinates.
/// Called only in comptime contexts (all arguments are comptime). /// Called only in comptime contexts (all arguments are comptime).
pub fn decodeFlatCoords( pub fn decodeFlatCoords(
comptime flat: usize, comptime flat: comptime_int,
comptime n: usize, comptime n: comptime_int,
comptime strd: [n]usize, comptime strd: [n]comptime_int,
) [n]usize { ) [n]usize {
var coords: [n]usize = undefined; var coords: [n]comptime_int = undefined;
var tmp = flat; var tmp = flat;
for (0..n) |i| { for (0..n) |i| {
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
const s1: Scales = T1.scales; const s1: Scales = T1.scales;
const s2: Scales = T2.scales; const s2: Scales = T2.scales;
comptime var out = Scales.initFill(.none); comptime var out = Scales.initFill(.none);
inline for (std.enums.values(Dimension)) |dim| { for (std.enums.values(Dimension)) |dim| {
const scale1 = comptime s1.get(dim); const scale1 = comptime s1.get(dim);
const scale2 = comptime s2.get(dim); const scale2 = comptime s2.get(dim);
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
@ -201,7 +201,7 @@ pub fn Tensor(
comptime T: type, comptime T: type,
comptime d_opt: Dimensions.ArgOpts, comptime d_opt: Dimensions.ArgOpts,
comptime s_opt: Scales.ArgOpts, comptime s_opt: Scales.ArgOpts,
comptime shape_: []const usize, comptime shape_: []const comptime_int,
) type { ) type {
comptime { comptime {
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
@ -223,10 +223,10 @@ pub fn Tensor(
pub const ValueType: type = T; pub const ValueType: type = T;
pub const dims: Dimensions = Dimensions.init(d_opt); pub const dims: Dimensions = Dimensions.init(d_opt);
pub const scales: Scales = Scales.init(s_opt); pub const scales: Scales = Scales.init(s_opt);
pub const shape: []const usize = shape_; pub const shape: []const comptime_int = shape_;
pub const rank: usize = shape_.len; pub const rank: comptime_int = shape_.len;
pub const total: usize = _total; pub const total: comptime_int = _total;
pub const strides_arr: [shape_.len]usize = _strides; pub const strides_arr: [shape_.len]comptime_int = _strides;
pub const ISTENSOR = true; pub const ISTENSOR = true;
/// Convert N-D coords (row-major) to flat index fully comptime. /// Convert N-D coords (row-major) to flat index fully comptime.
@ -359,9 +359,7 @@ pub fn Tensor(
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
const rr: Vec = broadcastToVec(RhsNorm, rr_base); const rr: Vec = broadcastToVec(RhsNorm, rr_base);
if (comptime isInt(T)) { if (comptime isInt(T)) {
var result: Vec = undefined; return .{ .data = @divTrunc(l, rr) };
inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
return .{ .data = result };
} else { } else {
return .{ .data = l / rr }; return .{ .data = l / rr };
} }
@ -379,19 +377,27 @@ pub fn Tensor(
scales.argsOpt(), scales.argsOpt(),
shape_, shape_,
) { ) {
if (comptime isInt(T)) { if (comptime exp == 0) return .{ .data = @splat(1) };
var result: Vec = undefined; if (comptime exp == 1) return self;
inline for (0..total) |i|
result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); var base = self.data;
return .{ .data = result };
} else {
const abs_exp = comptime @abs(exp);
var result: Vec = @splat(1); var result: Vec = @splat(1);
comptime var i = 0; comptime var e = @abs(exp);
inline while (i < abs_exp) : (i += 1) result *= self.data;
if (comptime exp < 0) result = @as(Vec, @splat(1)) / result; // $O(\log n)$ Exponentiation by squaring applied to the entire vector
return .{ .data = result }; inline while (e > 0) {
if (e % 2 == 1) {
result = if (comptime isInt(T)) result *| base else result * base;
} }
e /= 2;
if (e > 0) {
base = if (comptime isInt(T)) base *| base else base * base;
}
}
if (comptime !isInt(T) and exp < 0) {
result = @as(Vec, @splat(1)) / result;
}
return .{ .data = result };
} }
/// Square root of every element. All dimension exponents must be even. /// Square root of every element. All dimension exponents must be even.
@ -404,15 +410,16 @@ pub fn Tensor(
if (comptime !dims.isSquare()) if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
if (comptime @typeInfo(T) == .float) { if (comptime @typeInfo(T) == .float) {
return .{ .data = @sqrt(self.data) }; return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
} else { } else {
var result: Vec = undefined; const arr: [total]T = self.data; // Add this!
var res_arr: [total]T = undefined;
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
inline for (0..total) |i| { for (0..total) |i| {
const v = self.data[i]; const v = arr[i];
result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
} }
return .{ .data = result }; return .{ .data = res_arr };
} }
} }
@ -433,28 +440,34 @@ pub fn Tensor(
if (comptime Self == ActualDest) return self; if (comptime Self == ActualDest) return self;
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); // Run validation checks FIRST before dealing with types
const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT);
// If ratio is 1, just handle type conversion
if (comptime ratio == 1.0) {
if (comptime T == DestT) return .{ .data = self.data };
return .{
.data = switch (comptime @typeInfo(DestT)) {
.float => @floatFromInt(self.data), // or @floatCast
.int => @intFromFloat(self.data), // or @intCast
else => unreachable,
},
};
}
if (comptime !dims.eql(ActualDest.dims)) if (comptime !dims.eql(ActualDest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
if (comptime Self == ActualDest) return self; const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT);
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
if (comptime ratio == 1.0) {
const T_info = @typeInfo(T);
const Dest_info = @typeInfo(DestT);
return .{
.data = if (comptime T_info == .int and Dest_info == .int)
@as(DestVec, @intCast(self.data))
else if (comptime T_info == .float and Dest_info == .float)
@as(DestVec, @floatCast(self.data))
else if (comptime T_info == .int and Dest_info == .float)
@as(DestVec, @floatFromInt(self.data))
else if (comptime T_info == .float and Dest_info == .int)
@as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
else
unreachable,
};
}
if (comptime T == DestT) { if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) if (comptime @typeInfo(T) == .float)
@ -466,34 +479,34 @@ pub fn Tensor(
} else { } else {
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
const half: T = comptime @divTrunc(div_val, 2); const half: T = comptime @divTrunc(div_val, 2);
var result: DestVec = undefined;
inline for (0..total) |i| { if (comptime @typeInfo(T).int.signedness == .unsigned) {
const val = self.data[i]; return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
result[i] = if (val >= 0) } else {
@divTrunc(val + half, div_val) // Vectorized branchless negative handling
else const is_pos = self.data >= @as(Vec, @splat(0));
@divTrunc(val - half, div_val); const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
} }
return .{ .data = result };
} }
} }
var result: DestVec = undefined; // Cross-type fully vectorized casting with scales
inline for (0..total) |i| { const FVec = @Vector(total, f64);
const float_val: f64 = switch (comptime @typeInfo(T)) { const float_vec: FVec = switch (comptime @typeInfo(T)) {
.float => @floatCast(self.data[i]), .float => @floatCast(self.data),
.int => @floatFromInt(self.data[i]), .int => @floatFromInt(self.data),
else => unreachable, else => unreachable,
}; };
const scaled = float_val * ratio;
result[i] = switch (comptime @typeInfo(DestT)) { const scaled = float_vec * @as(FVec, @splat(ratio));
.float => @floatCast(scaled),
.int => @intFromFloat(@round(scaled)), return switch (comptime @typeInfo(DestT)) {
.float => .{ .data = @floatCast(scaled) },
.int => .{ .data = @intFromFloat(@round(scaled)) },
else => unreachable, else => unreachable,
}; };
} }
return .{ .data = result };
}
const CmpResult = if (total == 1) bool else [total]bool; const CmpResult = if (total == 1) bool else [total]bool;
@ -592,7 +605,7 @@ pub fn Tensor(
const sa = shapeRemoveAxis(shape_, axis_a); const sa = shapeRemoveAxis(shape_, axis_a);
const sb = shapeRemoveAxis(OT.shape, axis_b); const sb = shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = shapeCat(&sa, &sb); const rs_raw = shapeCat(&sa, &sb);
const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
break :blk Tensor( break :blk Tensor(
T, T,
dims.add(OT.dims).argsOpt(), dims.add(OT.dims).argsOpt(),
@ -606,7 +619,7 @@ pub fn Tensor(
const sa = comptime shapeRemoveAxis(shape_, axis_a); const sa = comptime shapeRemoveAxis(shape_, axis_a);
const sb = comptime shapeRemoveAxis(OT.shape, axis_b); const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = comptime shapeCat(&sa, &sb); const rs_raw = comptime shapeCat(&sa, &sb);
const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
const ResultType = Tensor( const ResultType = Tensor(
T, T,
@ -617,34 +630,85 @@ pub fn Tensor(
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
// FAST PATH: Dot Product
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
if (comptime !isInt(T)) {
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
} else {
// For integers, we do a vectorized saturating multiply,
// then convert to an array to do a saturating sum
const mul_arr: [total]T = a_data *| b_data;
var acc: T = 0;
for (mul_arr) |val| acc +|= val;
return .{ .data = @splat(acc) };
}
}
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
const a_arr: [total]T = a_data;
const b_arr: [OT.total]T = b_data;
// FAST PATH: 2D Matrix Multiplication
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
const rows = shape_[0];
const cols = OT.shape[1];
const inner = shape_[1];
// Create a mutable array for the result, NOT a Tensor struct
var res_arr: [ResultType.total]T = undefined;
for (0..rows) |i| {
for (0..cols) |j| {
var acc: T = 0;
for (0..inner) |id| {
const a_flat = i * _strides[0] + id * _strides[1];
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
// Use a_arr and b_arr here
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
}
// Write to the array
res_arr[i * cols + j] = acc;
}
}
// Return the initialized Tensor struct
return .{ .data = res_arr };
}
// FALLBACK PATH
const rs_raw_strides = comptime shapeStrides(&rs_raw); const rs_raw_strides = comptime shapeStrides(&rs_raw);
var result: ResultType = .{ .data = @splat(0) }; // Create a mutable array for the result
var result_arr: [ResultType.total]T = undefined;
inline for (0..ResultType.total) |res_flat| { for (0..ResultType.total) |res_flat| {
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; var a_free: [sa.len]usize = undefined;
const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; for (0..sa.len) |i| a_free[i] = res_coords[i];
var b_free: [sb.len]usize = undefined;
for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];
var acc: T = 0; var acc: T = 0;
inline for (0..k) |ki| { for (0..k) |ki| {
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); const a_coords = insertAxis(rank, axis_a, ki, &a_free);
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
if (comptime isInt(T)) // Use a_arr and b_arr here
acc +|= a_data[a_flat] *| b_data[b_flat] if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
else
acc += a_data[a_flat] * b_data[b_flat];
} }
result.data[res_flat] = acc; // Write to the array
result_arr[res_flat] = acc;
} }
return result;
// Return the initialized Tensor struct
return .{ .data = result_arr };
} }
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3. /// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
@ -724,7 +788,8 @@ pub fn Tensor(
} }
} else { } else {
try writer.writeAll("("); try writer.writeAll("(");
inline for (0..total) |i| { const max_to_print = 6;
inline for (0..@min(total, max_to_print)) |i| {
if (i > 0) try writer.writeAll(", "); if (i > 0) try writer.writeAll(", ");
switch (@typeInfo(T)) { switch (@typeInfo(T)) {
.float, .comptime_float => try writer.printFloat(self.data[i], options), .float, .comptime_float => try writer.printFloat(self.data[i], options),
@ -736,6 +801,8 @@ pub fn Tensor(
}), }),
else => unreachable, else => unreachable,
} }
if (comptime i == max_to_print - 1 and total != max_to_print - 1)
try writer.writeAll(", ...");
} }
try writer.writeAll(")"); try writer.writeAll(")");
} }

View File

@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void {
try bench_Scalar(&stdout_writer.interface); try bench_Scalar(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
// try bench_vsNative(&stdout_writer.interface); try bench_vsNative(&stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try bench_crossTypeVsNative(&stdout_writer.interface); try bench_crossTypeVsNative(&stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
try bench_Vector(&stdout_writer.interface); try bench_Vector(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
} }