Removed all inline for (0..total) loops, replacing each with either a vector builtin or a plain for loop without inline

This prevents a giant binary for Tensor types with a large number of Scalar elements: each inline iteration was being duplicated in the emitted machine code.
AdrienBouvais 2026-04-27 16:11:46 +02:00
parent cd954b379b
commit 44aaa8a8b2
4 changed files with 190 additions and 123 deletions
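Background for the change: inline for in Zig unrolls the loop body at compile time, so a loop over 0..total stamps out one copy of the body per element, for every instantiated Tensor type. A minimal sketch of the before/after pattern (illustrative only; sumUnrolled and sumLoop are hypothetical names, not from this repo):

    fn sumUnrolled(v: @Vector(1024, i32)) i32 {
        var acc: i32 = 0;
        // inline for: the body below is duplicated 1024 times in the binary
        inline for (0..1024) |i| acc +|= v[i];
        return acc;
    }

    fn sumLoop(v: @Vector(1024, i32)) i32 {
        const arr: [1024]i32 = v; // zero-cost coercion to an array for runtime indexing
        var acc: i32 = 0;
        // plain for: one compact runtime loop regardless of length
        for (arr) |x| acc +|= x;
        return acc;
    }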


@@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int),
 /// Unspecified dimensions default to 0.
 pub fn init(comptime init_val: ArgOpts) Self {
     var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
-    inline for (std.meta.fields(@TypeOf(init_val))) |f|
+    for (std.meta.fields(@TypeOf(init_val))) |f|
         s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
     return s;
 }
 pub fn initFill(comptime val: comptime_int) Self {
-    return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
+    comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
 }
 pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
-    return self.data.get(key);
+    comptime return self.data.get(key);
 }
 pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
@@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void
 pub fn argsOpt(self: Self) ArgOpts {
     var args: ArgOpts = undefined;
-    inline for (std.enums.values(Dimension)) |d|
+    for (std.enums.values(Dimension)) |d|
         @field(args, @tagName(d)) = self.get(d);
-    return args;
+    comptime return args;
 }
 /// Add exponents component-wise. Used internally by `mul`.
 pub fn add(comptime a: Self, comptime b: Self) Self {
     var result = Self.initFill(0);
-    inline for (std.enums.values(Dimension)) |d|
+    for (std.enums.values(Dimension)) |d|
         result.set(d, a.get(d) + b.get(d));
-    return result;
+    comptime return result;
 }
 /// Subtract exponents component-wise. Used internally by `div`.
 pub fn sub(comptime a: Self, comptime b: Self) Self {
     var result = Self.initFill(0);
-    inline for (std.enums.values(Dimension)) |d|
+    for (std.enums.values(Dimension)) |d|
         result.set(d, a.get(d) - b.get(d));
-    return result;
+    comptime return result;
 }
 /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
 pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
     var result = Self.initFill(0);
-    inline for (std.enums.values(Dimension)) |d|
+    for (std.enums.values(Dimension)) |d|
         result.set(d, a.get(d) * exp);
-    return result;
+    comptime return result;
 }
 pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
     var result = Self.initFill(0);
-    inline for (std.enums.values(Dimension)) |d|
+    for (std.enums.values(Dimension)) |d|
         result.set(d, a.get(d) / exp);
-    return result;
+    comptime return result;
 }
 /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.

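The comptime return lines above are what let the plain for loops replace inline for: forcing the returned value to be comptime-known forces the loop that computes it to run at compile time, so no loop code lands in the binary, and iterating comptime-only slices like std.meta.fields(...) stays legal. A minimal sketch of the idiom (hypothetical Point and fieldCount, not from this repo):

    const std = @import("std");

    const Point = struct { x: i32 = 0, y: i32 = 0 };

    fn fieldCount(comptime T: type) usize {
        var n: usize = 0;
        // A plain `for` over std.meta.fields only compiles in a comptime
        // context; `comptime return` guarantees the value (and thus this
        // loop) is evaluated at compile time.
        for (std.meta.fields(T)) |_| n += 1;
        comptime return n;
    }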

@@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) {
 }
 pub inline fn getFactor(self: @This()) comptime_float {
-    return comptime switch (self) {
+    comptime return switch (self) {
     // Standard SI Exponents
     inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
@@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) {
     inline .lb => 453.59237,
     inline .st => 6350.29318,
-    inline else => @floatFromInt(@intFromEnum(self)),
+    else => @floatFromInt(@intFromEnum(self)),
     };
 }
 };
@@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self {
     else
         s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
     }
-    return s;
+    return comptime s;
 }
 pub fn initFill(comptime val: UnitScale) Self {
-    return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
+    comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
 }
 pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
@@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
 }
 pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
-    comptime self.data.set(key, val);
+    self.data.set(key, val);
 }
 pub fn argsOpt(self: Self) ArgOpts {
@@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float
         factor /= base;
         }
     }
-    return comptime factor;
+    comptime return factor;
 }

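A sketch of the arithmetic these factors drive: per the SI branch above, .k yields 10^3 = 1000 and .none yields 10^0 = 1, so converting kilometres to metres uses ratio = 1000 / 1 = 1000 and 2 km scales to 2000 m. The `to` conversion in the tensor file below multiplies or divides by exactly this kind of comptime-computed ratio.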

@@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
 // Comptime utilities
 //
-pub fn shapeTotal(comptime shape: []const usize) usize {
-    var t: usize = 1;
+pub fn shapeTotal(comptime shape: []const comptime_int) usize {
+    var t: comptime_int = 1;
     for (shape) |s| t *= s;
     return t;
 }
 /// Check if two shapes are strictly identical.
-pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
+pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
     if (a.len != b.len) return false;
     for (a, 0..) |v, i|
         if (v != b[i]) return false;
@@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
 /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
 /// e.g. shape {3, 4} strides {4, 1}
 ///      shape {2, 3, 4} strides {12, 4, 1}
-pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize {
-    var st: [shape.len]usize = undefined;
+pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
+    var st: [shape.len]comptime_int = undefined;
     if (shape.len == 0) return st;
     st[shape.len - 1] = 1;
     if (shape.len > 1) {
-        var i: usize = shape.len - 1;
+        var i: comptime_int = shape.len - 1;
         while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
     }
     return st;
 }
 /// Return a copy of `shape` with the element at `axis` removed.
-pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize {
-    var out: [shape.len - 1]usize = undefined;
-    var j: usize = 0;
+pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
+    var out: [shape.len - 1]comptime_int = undefined;
+    var j: comptime_int = 0;
     for (shape, 0..) |v, i| {
         if (i != axis) {
             out[j] = v;
@@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha
 }
 /// Concatenate two compile-time slices.
-pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize {
-    var out: [a.len + b.len]usize = undefined;
+pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
+    var out: [a.len + b.len]comptime_int = undefined;
     for (a, 0..) |v, i| out[i] = v;
     for (b, 0..) |v, i| out[a.len + i] = v;
     return out;
@@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const b
 /// Decode a flat row-major index into N-D coordinates.
-/// Called only in comptime contexts (all arguments are comptime).
+/// `flat` may be a runtime value (the matmul fallback path below now calls
+/// this at runtime); `n` and `strd` must still be comptime.
 pub fn decodeFlatCoords(
-    comptime flat: usize,
-    comptime n: usize,
-    comptime strd: [n]usize,
+    flat: usize,
+    comptime n: comptime_int,
+    comptime strd: [n]comptime_int,
 ) [n]usize {
     var coords: [n]usize = undefined;
     var tmp = flat;
     for (0..n) |i| {
         coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
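Worked example: with strides {4, 1} (shape {3, 4}) and flat = 7, the loop yields coords[0] = 7 / 4 = 1; assuming the unshown tail of the loop reduces tmp to the remainder (7 - 1 * 4 = 3), coords[1] = 3 / 1 = 3, giving coordinates {1, 3}.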
@@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
     const s1: Scales = T1.scales;
     const s2: Scales = T2.scales;
     comptime var out = Scales.initFill(.none);
-    inline for (std.enums.values(Dimension)) |dim| {
+    for (std.enums.values(Dimension)) |dim| {
         const scale1 = comptime s1.get(dim);
         const scale2 = comptime s2.get(dim);
         out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
@@ -201,7 +201,7 @@ pub fn Tensor(
     comptime T: type,
     comptime d_opt: Dimensions.ArgOpts,
     comptime s_opt: Scales.ArgOpts,
-    comptime shape_: []const usize,
+    comptime shape_: []const comptime_int,
 ) type {
 comptime {
     if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
@@ -223,10 +223,10 @@ pub fn Tensor(
 pub const ValueType: type = T;
 pub const dims: Dimensions = Dimensions.init(d_opt);
 pub const scales: Scales = Scales.init(s_opt);
-pub const shape: []const usize = shape_;
-pub const rank: usize = shape_.len;
-pub const total: usize = _total;
-pub const strides_arr: [shape_.len]usize = _strides;
+pub const shape: []const comptime_int = shape_;
+pub const rank: comptime_int = shape_.len;
+pub const total: comptime_int = _total;
+pub const strides_arr: [shape_.len]comptime_int = _strides;
 pub const ISTENSOR = true;
 /// Convert N-D coords (row-major) to flat index fully comptime.
@@ -359,9 +359,7 @@ pub fn Tensor(
 const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
 const rr: Vec = broadcastToVec(RhsNorm, rr_base);
 if (comptime isInt(T)) {
-    var result: Vec = undefined;
-    inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
-    return .{ .data = result };
+    return .{ .data = @divTrunc(l, rr) };
 } else {
     return .{ .data = l / rr };
 }
@@ -379,19 +377,27 @@ pub fn Tensor(
     scales.argsOpt(),
     shape_,
 ) {
-    if (comptime isInt(T)) {
-        var result: Vec = undefined;
-        inline for (0..total) |i|
-            result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T);
-        return .{ .data = result };
-    } else {
-        const abs_exp = comptime @abs(exp);
-        var result: Vec = @splat(1);
-        comptime var i = 0;
-        inline while (i < abs_exp) : (i += 1) result *= self.data;
-        if (comptime exp < 0) result = @as(Vec, @splat(1)) / result;
-        return .{ .data = result };
-    }
+    if (comptime exp == 0) return .{ .data = @splat(1) };
+    if (comptime exp == 1) return self;
+    var base = self.data;
+    var result: Vec = @splat(1);
+    comptime var e = @abs(exp);
+    // O(log n) exponentiation by squaring applied to the entire vector
+    inline while (e > 0) {
+        if (e % 2 == 1) {
+            result = if (comptime isInt(T)) result *| base else result * base;
+        }
+        e /= 2;
+        if (e > 0) {
+            base = if (comptime isInt(T)) base *| base else base * base;
+        }
+    }
+    if (comptime !isInt(T) and exp < 0) {
+        result = @as(Vec, @splat(1)) / result;
+    }
+    return .{ .data = result };
 }
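To trace the squaring loop, take exp = 5: e = 5 is odd, so result = base = x; then e = 2 and base = x^2; e = 2 is even, so only e = 1 and base = x^4; e = 1 is odd, so result = x * x^4 = x^5. That is three vector multiplies instead of four, and the gap widens logarithmically for larger exponents. Since exp is comptime-known, the inline while still unrolls, but only O(log exp) times rather than exp times.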
/// Square root of every element. All dimension exponents must be even.
@@ -404,15 +410,16 @@ pub fn Tensor(
 if (comptime !dims.isSquare())
     @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
 if (comptime @typeInfo(T) == .float) {
-    return .{ .data = @sqrt(self.data) };
+    return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
 } else {
-    var result: Vec = undefined;
+    const arr: [total]T = self.data; // Zero-cost coercion to an array for runtime indexing
+    var res_arr: [total]T = undefined;
     const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
-    inline for (0..total) |i| {
-        const v = self.data[i];
-        result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
+    for (0..total) |i| {
+        const v = arr[i];
+        res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
     }
-    return .{ .data = result };
+    return .{ .data = res_arr };
 }
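The vector-to-array coercion used in the integer branch is the key enabler here: Zig lets a @Vector coerce to and from a fixed-size array of the same length and element type at no cost, which makes runtime element loops possible without inline for. A standalone sketch (illustrative names, not from this repo):

    const std = @import("std");

    fn isqrtVec(comptime N: usize, v: @Vector(N, i32)) @Vector(N, i32) {
        const arr: [N]i32 = v; // vector -> array coercion, no copy cost
        var out: [N]i32 = undefined;
        for (arr, 0..) |x, i|
            out[i] = if (x < 0) 0 else @as(i32, @intCast(std.math.sqrt(@as(u32, @intCast(x)))));
        return out; // array -> vector coercion on return
    }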
@@ -433,28 +440,34 @@ pub fn Tensor(
-    if (comptime Self == ActualDest) return self;
-    const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
-    const DestT = ActualDest.ValueType;
-    const DestVec = @Vector(total, DestT);
-    // If ratio is 1, just handle type conversion
-    if (comptime ratio == 1.0) {
-        if (comptime T == DestT) return .{ .data = self.data };
-        return .{
-            .data = switch (comptime @typeInfo(DestT)) {
-                .float => @floatFromInt(self.data), // or @floatCast
-                .int => @intFromFloat(self.data), // or @intCast
-                else => unreachable,
-            },
-        };
-    }
+    // Run validation checks FIRST before dealing with types
+    if (comptime !dims.eql(ActualDest.dims))
+        @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
+    if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
+        @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
+    if (comptime Self == ActualDest) return self;
+    const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
+    const DestT = ActualDest.ValueType;
+    const DestVec = @Vector(total, DestT);
+    // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
+    if (comptime ratio == 1.0) {
+        const T_info = @typeInfo(T);
+        const Dest_info = @typeInfo(DestT);
+        return .{
+            .data = if (comptime T_info == .int and Dest_info == .int)
+                @as(DestVec, @intCast(self.data))
+            else if (comptime T_info == .float and Dest_info == .float)
+                @as(DestVec, @floatCast(self.data))
+            else if (comptime T_info == .int and Dest_info == .float)
+                @as(DestVec, @floatFromInt(self.data))
+            else if (comptime T_info == .float and Dest_info == .int)
+                @as(DestVec, @intFromFloat(self.data)) // or @intFromFloat(@round(self.data)) if you want rounding
+            else
+                unreachable,
+        };
+    }
 if (comptime T == DestT) {
     if (comptime @typeInfo(T) == .float)
@@ -466,33 +479,33 @@ pub fn Tensor(
 } else {
     const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
     const half: T = comptime @divTrunc(div_val, 2);
-    var result: DestVec = undefined;
-    inline for (0..total) |i| {
-        const val = self.data[i];
-        result[i] = if (val >= 0)
-            @divTrunc(val + half, div_val)
-        else
-            @divTrunc(val - half, div_val);
-    }
-    return .{ .data = result };
+    if (comptime @typeInfo(T).int.signedness == .unsigned) {
+        return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
+    } else {
+        // Vectorized branchless negative handling
+        const is_pos = self.data >= @as(Vec, @splat(0));
+        const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
+        return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
+    }
 }
 }
-var result: DestVec = undefined;
-inline for (0..total) |i| {
-    const float_val: f64 = switch (comptime @typeInfo(T)) {
-        .float => @floatCast(self.data[i]),
-        .int => @floatFromInt(self.data[i]),
-        else => unreachable,
-    };
-    const scaled = float_val * ratio;
-    result[i] = switch (comptime @typeInfo(DestT)) {
-        .float => @floatCast(scaled),
-        .int => @intFromFloat(@round(scaled)),
-        else => unreachable,
-    };
-}
-return .{ .data = result };
+// Cross-type fully vectorized casting with scales
+const FVec = @Vector(total, f64);
+const float_vec: FVec = switch (comptime @typeInfo(T)) {
+    .float => @floatCast(self.data),
+    .int => @floatFromInt(self.data),
+    else => unreachable,
+};
+const scaled = float_vec * @as(FVec, @splat(ratio));
+return switch (comptime @typeInfo(DestT)) {
+    .float => .{ .data = @floatCast(scaled) },
+    .int => .{ .data = @intFromFloat(@round(scaled)) },
+    else => unreachable,
+};
 }
const CmpResult = if (total == 1) bool else [total]bool;
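A worked pass through the signed rounding branch, assuming a millimetre-to-centimetre conversion (ratio = 0.1, so div_val = 10 and half = 5): value 14 becomes @divTrunc(14 + 5, 10) = 1, value 15 becomes @divTrunc(15 + 5, 10) = 2, and value -15 becomes @divTrunc(-15 - 5, 10) = -2. That is round-half-away-from-zero with no per-element branch, since @select picks +half or -half across the whole vector at once.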
@@ -592,7 +605,7 @@ pub fn Tensor(
 const sa = shapeRemoveAxis(shape_, axis_a);
 const sb = shapeRemoveAxis(OT.shape, axis_b);
 const rs_raw = shapeCat(&sa, &sb);
-const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw;
+const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
 break :blk Tensor(
     T,
     dims.add(OT.dims).argsOpt(),
@@ -606,7 +619,7 @@ pub fn Tensor(
 const sa = comptime shapeRemoveAxis(shape_, axis_a);
 const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
 const rs_raw = comptime shapeCat(&sa, &sb);
-const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
+const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
 const ResultType = Tensor(
     T,
@@ -617,34 +630,85 @@ pub fn Tensor(
 const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
 const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
 const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
 const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
+// FAST PATH: dot product
+if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
+    if (comptime !isInt(T)) {
+        return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
+    } else {
+        // For integers, do a vectorized saturating multiply, then coerce to
+        // an array for an element-wise saturating sum (@reduce cannot saturate)
+        const mul_arr: [total]T = a_data *| b_data;
+        var acc: T = 0;
+        for (mul_arr) |val| acc +|= val;
+        return .{ .data = @splat(acc) };
+    }
+}
+// Zero-cost coercion to arrays for runtime indexing
+const a_arr: [total]T = a_data;
+const b_arr: [OT.total]T = b_data;
+// FAST PATH: 2D matrix multiplication
+if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
+    const rows = shape_[0];
+    const cols = OT.shape[1];
+    const inner = shape_[1];
+    // Accumulate into a mutable array, not a Tensor struct
+    var res_arr: [ResultType.total]T = undefined;
+    for (0..rows) |i| {
+        for (0..cols) |j| {
+            var acc: T = 0;
+            for (0..inner) |id| {
+                const a_flat = i * _strides[0] + id * _strides[1];
+                const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
+                if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
+            }
+            res_arr[i * cols + j] = acc;
+        }
+    }
+    return .{ .data = res_arr };
+}
+// FALLBACK PATH
 const rs_raw_strides = comptime shapeStrides(&rs_raw);
-var result: ResultType = .{ .data = @splat(0) };
-inline for (0..ResultType.total) |res_flat| {
-    const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
-    const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*;
-    const b_free: [sb.len]usize = comptime res_coords[sa.len..].*;
+var result_arr: [ResultType.total]T = undefined;
+for (0..ResultType.total) |res_flat| {
+    const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
+    var a_free: [sa.len]usize = undefined;
+    for (0..sa.len) |i| a_free[i] = res_coords[i];
+    var b_free: [sb.len]usize = undefined;
+    for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];
     var acc: T = 0;
-    inline for (0..k) |ki| {
-        const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free);
-        const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free);
-        const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
-        const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
-        if (comptime isInt(T))
-            acc +|= a_data[a_flat] *| b_data[b_flat]
-        else
-            acc += a_data[a_flat] * b_data[b_flat];
+    for (0..k) |ki| {
+        const a_coords = insertAxis(rank, axis_a, ki, &a_free);
+        const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
+        const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
+        const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
+        if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
     }
-    result.data[res_flat] = acc;
+    result_arr[res_flat] = acc;
 }
-return result;
+return .{ .data = result_arr };
 }
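The integer dot-product trick above, vector multiply followed by a scalar saturating fold, in a standalone form (dotSat is an illustrative name, not part of this repo):

    fn dotSat(comptime N: usize, a: @Vector(N, i32), b: @Vector(N, i32)) i32 {
        const prod: [N]i32 = a *| b; // saturating multiply, still fully vectorized
        var acc: i32 = 0;
        for (prod) |p| acc +|= p; // saturating accumulation must be element-wise
        return acc;
    }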
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
@@ -724,7 +788,8 @@ pub fn Tensor(
 }
 } else {
     try writer.writeAll("(");
-    inline for (0..total) |i| {
+    const max_to_print = 6;
+    inline for (0..@min(total, max_to_print)) |i| {
         if (i > 0) try writer.writeAll(", ");
         switch (@typeInfo(T)) {
             .float, .comptime_float => try writer.printFloat(self.data[i], options),
@@ -736,6 +801,8 @@ pub fn Tensor(
     }),
     else => unreachable,
 }
+if (comptime i == max_to_print - 1 and total > max_to_print)
+    try writer.writeAll(", ...");
 }
 try writer.writeAll(")");
 }
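With max_to_print = 6, an eight-element tensor now prints as (v1, v2, v3, v4, v5, v6, ...), while tensors of six or fewer elements print in full with no trailing ellipsis; the guard fires only when total > max_to_print, i.e. only when elements were actually omitted.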


@@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void {
     try bench_Scalar(&stdout_writer.interface);
     try stdout_writer.flush();
-    // try bench_vsNative(&stdout_writer.interface);
-    // try stdout_writer.flush();
-    // try bench_crossTypeVsNative(&stdout_writer.interface);
-    // try stdout_writer.flush();
+    try bench_vsNative(&stdout_writer.interface);
+    try stdout_writer.flush();
+    try bench_crossTypeVsNative(&stdout_writer.interface);
+    try stdout_writer.flush();
     try bench_Vector(&stdout_writer.interface);
     try stdout_writer.flush();
 }