Replaced every `inline for (0..total)` loop with either a vectorized builtin or a plain (non-inline) `for` loop.
This prevents the compiled binary from becoming huge for a Tensor that holds a large number of Scalars.
This commit is contained in:
parent
cd954b379b
commit
44aaa8a8b2
@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int),
|
|||||||
/// Unspecified dimensions default to 0.
|
/// Unspecified dimensions default to 0.
|
||||||
pub fn init(comptime init_val: ArgOpts) Self {
|
pub fn init(comptime init_val: ArgOpts) Self {
|
||||||
var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
|
var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
|
||||||
inline for (std.meta.fields(@TypeOf(init_val))) |f|
|
for (std.meta.fields(@TypeOf(init_val))) |f|
|
||||||
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initFill(comptime val: comptime_int) Self {
|
pub fn initFill(comptime val: comptime_int) Self {
|
||||||
return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
|
comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
|
pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
|
||||||
return self.data.get(key);
|
comptime return self.data.get(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
|
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
|
||||||
@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void
|
|||||||
|
|
||||||
pub fn argsOpt(self: Self) ArgOpts {
|
pub fn argsOpt(self: Self) ArgOpts {
|
||||||
var args: ArgOpts = undefined;
|
var args: ArgOpts = undefined;
|
||||||
inline for (std.enums.values(Dimension)) |d|
|
for (std.enums.values(Dimension)) |d|
|
||||||
@field(args, @tagName(d)) = self.get(d);
|
@field(args, @tagName(d)) = self.get(d);
|
||||||
return args;
|
comptime return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add exponents component-wise. Used internally by `mul`.
|
/// Add exponents component-wise. Used internally by `mul`.
|
||||||
pub fn add(comptime a: Self, comptime b: Self) Self {
|
pub fn add(comptime a: Self, comptime b: Self) Self {
|
||||||
var result = Self.initFill(0);
|
var result = Self.initFill(0);
|
||||||
inline for (std.enums.values(Dimension)) |d|
|
for (std.enums.values(Dimension)) |d|
|
||||||
result.set(d, a.get(d) + b.get(d));
|
result.set(d, a.get(d) + b.get(d));
|
||||||
return result;
|
comptime return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Subtract exponents component-wise. Used internally by `div`.
|
/// Subtract exponents component-wise. Used internally by `div`.
|
||||||
pub fn sub(comptime a: Self, comptime b: Self) Self {
|
pub fn sub(comptime a: Self, comptime b: Self) Self {
|
||||||
var result = Self.initFill(0);
|
var result = Self.initFill(0);
|
||||||
inline for (std.enums.values(Dimension)) |d|
|
for (std.enums.values(Dimension)) |d|
|
||||||
result.set(d, a.get(d) - b.get(d));
|
result.set(d, a.get(d) - b.get(d));
|
||||||
return result;
|
comptime return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
|
/// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
|
||||||
pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
|
pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
|
||||||
var result = Self.initFill(0);
|
var result = Self.initFill(0);
|
||||||
inline for (std.enums.values(Dimension)) |d|
|
for (std.enums.values(Dimension)) |d|
|
||||||
result.set(d, a.get(d) * exp);
|
result.set(d, a.get(d) * exp);
|
||||||
return result;
|
comptime return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
|
pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
|
||||||
var result = Self.initFill(0);
|
var result = Self.initFill(0);
|
||||||
inline for (std.enums.values(Dimension)) |d|
|
inline for (std.enums.values(Dimension)) |d|
|
||||||
result.set(d, a.get(d) / exp);
|
result.set(d, a.get(d) / exp);
|
||||||
return result;
|
comptime return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.
|
/// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.
|
||||||
|
|||||||
@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn getFactor(self: @This()) comptime_float {
|
pub inline fn getFactor(self: @This()) comptime_float {
|
||||||
return comptime switch (self) {
|
comptime return switch (self) {
|
||||||
// Standard SI Exponents
|
// Standard SI Exponents
|
||||||
inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
|
inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) {
|
|||||||
inline .lb => 453.59237,
|
inline .lb => 453.59237,
|
||||||
inline .st => 6350.29318,
|
inline .st => 6350.29318,
|
||||||
|
|
||||||
inline else => @floatFromInt(@intFromEnum(self)),
|
else => @floatFromInt(@intFromEnum(self)),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self {
|
|||||||
else
|
else
|
||||||
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
||||||
}
|
}
|
||||||
return s;
|
return comptime s;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initFill(comptime val: UnitScale) Self {
|
pub fn initFill(comptime val: UnitScale) Self {
|
||||||
return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
|
comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
|
pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
|
||||||
@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
|
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
|
||||||
comptime self.data.set(key, val);
|
self.data.set(key, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn argsOpt(self: Self) ArgOpts {
|
pub fn argsOpt(self: Self) ArgOpts {
|
||||||
@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float
|
|||||||
factor /= base;
|
factor /= base;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return comptime factor;
|
comptime return factor;
|
||||||
}
|
}
|
||||||
|
|||||||
269
src/Tensor.zig
269
src/Tensor.zig
@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
|
|||||||
// Comptime utilities
|
// Comptime utilities
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
pub fn shapeTotal(comptime shape: []const usize) usize {
|
pub fn shapeTotal(comptime shape: []const comptime_int) usize {
|
||||||
var t: usize = 1;
|
var t: comptime_int = 1;
|
||||||
for (shape) |s| t *= s;
|
for (shape) |s| t *= s;
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if two shapes are strictly identical.
|
/// Check if two shapes are strictly identical.
|
||||||
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
|
||||||
if (a.len != b.len) return false;
|
if (a.len != b.len) return false;
|
||||||
for (a, 0..) |v, i|
|
for (a, 0..) |v, i|
|
||||||
if (v != b[i]) return false;
|
if (v != b[i]) return false;
|
||||||
@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
|||||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||||
/// e.g. shape {3, 4} → strides {4, 1}
|
/// e.g. shape {3, 4} → strides {4, 1}
|
||||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||||
pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize {
|
pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
|
||||||
var st: [shape.len]usize = undefined;
|
var st: [shape.len]comptime_int = undefined;
|
||||||
if (shape.len == 0) return st;
|
if (shape.len == 0) return st;
|
||||||
st[shape.len - 1] = 1;
|
st[shape.len - 1] = 1;
|
||||||
if (shape.len > 1) {
|
if (shape.len > 1) {
|
||||||
var i: usize = shape.len - 1;
|
var i: comptime_int = shape.len - 1;
|
||||||
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
|
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
|
||||||
}
|
}
|
||||||
return st;
|
return st;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return a copy of `shape` with the element at `axis` removed.
|
/// Return a copy of `shape` with the element at `axis` removed.
|
||||||
pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize {
|
pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
|
||||||
var out: [shape.len - 1]usize = undefined;
|
var out: [shape.len - 1]comptime_int = undefined;
|
||||||
var j: usize = 0;
|
var j: comptime_int = 0;
|
||||||
for (shape, 0..) |v, i| {
|
for (shape, 0..) |v, i| {
|
||||||
if (i != axis) {
|
if (i != axis) {
|
||||||
out[j] = v;
|
out[j] = v;
|
||||||
@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenate two compile-time slices.
|
/// Concatenate two compile-time slices.
|
||||||
pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize {
|
pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||||
var out: [a.len + b.len]usize = undefined;
|
var out: [a.len + b.len]comptime_int = undefined;
|
||||||
for (a, 0..) |v, i| out[i] = v;
|
for (a, 0..) |v, i| out[i] = v;
|
||||||
for (b, 0..) |v, i| out[a.len + i] = v;
|
for (b, 0..) |v, i| out[a.len + i] = v;
|
||||||
return out;
|
return out;
|
||||||
@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b
|
|||||||
/// Decode a flat row-major index into N-D coordinates.
|
/// Decode a flat row-major index into N-D coordinates.
|
||||||
/// Called only in comptime contexts (all arguments are comptime).
|
/// Called only in comptime contexts (all arguments are comptime).
|
||||||
pub fn decodeFlatCoords(
|
pub fn decodeFlatCoords(
|
||||||
comptime flat: usize,
|
comptime flat: comptime_int,
|
||||||
comptime n: usize,
|
comptime n: comptime_int,
|
||||||
comptime strd: [n]usize,
|
comptime strd: [n]comptime_int,
|
||||||
) [n]usize {
|
) [n]usize {
|
||||||
var coords: [n]usize = undefined;
|
var coords: [n]comptime_int = undefined;
|
||||||
var tmp = flat;
|
var tmp = flat;
|
||||||
for (0..n) |i| {
|
for (0..n) |i| {
|
||||||
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
|
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
|
||||||
@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
|||||||
const s1: Scales = T1.scales;
|
const s1: Scales = T1.scales;
|
||||||
const s2: Scales = T2.scales;
|
const s2: Scales = T2.scales;
|
||||||
comptime var out = Scales.initFill(.none);
|
comptime var out = Scales.initFill(.none);
|
||||||
inline for (std.enums.values(Dimension)) |dim| {
|
for (std.enums.values(Dimension)) |dim| {
|
||||||
const scale1 = comptime s1.get(dim);
|
const scale1 = comptime s1.get(dim);
|
||||||
const scale2 = comptime s2.get(dim);
|
const scale2 = comptime s2.get(dim);
|
||||||
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
||||||
@ -201,7 +201,7 @@ pub fn Tensor(
|
|||||||
comptime T: type,
|
comptime T: type,
|
||||||
comptime d_opt: Dimensions.ArgOpts,
|
comptime d_opt: Dimensions.ArgOpts,
|
||||||
comptime s_opt: Scales.ArgOpts,
|
comptime s_opt: Scales.ArgOpts,
|
||||||
comptime shape_: []const usize,
|
comptime shape_: []const comptime_int,
|
||||||
) type {
|
) type {
|
||||||
comptime {
|
comptime {
|
||||||
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
||||||
@ -223,10 +223,10 @@ pub fn Tensor(
|
|||||||
pub const ValueType: type = T;
|
pub const ValueType: type = T;
|
||||||
pub const dims: Dimensions = Dimensions.init(d_opt);
|
pub const dims: Dimensions = Dimensions.init(d_opt);
|
||||||
pub const scales: Scales = Scales.init(s_opt);
|
pub const scales: Scales = Scales.init(s_opt);
|
||||||
pub const shape: []const usize = shape_;
|
pub const shape: []const comptime_int = shape_;
|
||||||
pub const rank: usize = shape_.len;
|
pub const rank: comptime_int = shape_.len;
|
||||||
pub const total: usize = _total;
|
pub const total: comptime_int = _total;
|
||||||
pub const strides_arr: [shape_.len]usize = _strides;
|
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||||
pub const ISTENSOR = true;
|
pub const ISTENSOR = true;
|
||||||
|
|
||||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||||
@ -359,9 +359,7 @@ pub fn Tensor(
|
|||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
if (comptime isInt(T)) {
|
if (comptime isInt(T)) {
|
||||||
var result: Vec = undefined;
|
return .{ .data = @divTrunc(l, rr) };
|
||||||
inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
|
|
||||||
return .{ .data = result };
|
|
||||||
} else {
|
} else {
|
||||||
return .{ .data = l / rr };
|
return .{ .data = l / rr };
|
||||||
}
|
}
|
||||||
@ -379,19 +377,27 @@ pub fn Tensor(
|
|||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
if (comptime isInt(T)) {
|
if (comptime exp == 0) return .{ .data = @splat(1) };
|
||||||
var result: Vec = undefined;
|
if (comptime exp == 1) return self;
|
||||||
inline for (0..total) |i|
|
|
||||||
result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T);
|
var base = self.data;
|
||||||
return .{ .data = result };
|
var result: Vec = @splat(1);
|
||||||
} else {
|
comptime var e = @abs(exp);
|
||||||
const abs_exp = comptime @abs(exp);
|
|
||||||
var result: Vec = @splat(1);
|
// O(log n) exponentiation by squaring, applied to the entire vector at once
|
||||||
comptime var i = 0;
|
inline while (e > 0) {
|
||||||
inline while (i < abs_exp) : (i += 1) result *= self.data;
|
if (e % 2 == 1) {
|
||||||
if (comptime exp < 0) result = @as(Vec, @splat(1)) / result;
|
result = if (comptime isInt(T)) result *| base else result * base;
|
||||||
return .{ .data = result };
|
}
|
||||||
|
e /= 2;
|
||||||
|
if (e > 0) {
|
||||||
|
base = if (comptime isInt(T)) base *| base else base * base;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (comptime !isInt(T) and exp < 0) {
|
||||||
|
result = @as(Vec, @splat(1)) / result;
|
||||||
|
}
|
||||||
|
return .{ .data = result };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
@ -404,15 +410,16 @@ pub fn Tensor(
|
|||||||
if (comptime !dims.isSquare())
|
if (comptime !dims.isSquare())
|
||||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||||
if (comptime @typeInfo(T) == .float) {
|
if (comptime @typeInfo(T) == .float) {
|
||||||
return .{ .data = @sqrt(self.data) };
|
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
||||||
} else {
|
} else {
|
||||||
var result: Vec = undefined;
|
const arr: [total]T = self.data; // Add this!
|
||||||
|
var res_arr: [total]T = undefined;
|
||||||
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
||||||
inline for (0..total) |i| {
|
for (0..total) |i| {
|
||||||
const v = self.data[i];
|
const v = arr[i];
|
||||||
result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
||||||
}
|
}
|
||||||
return .{ .data = result };
|
return .{ .data = res_arr };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,28 +440,34 @@ pub fn Tensor(
|
|||||||
|
|
||||||
if (comptime Self == ActualDest) return self;
|
if (comptime Self == ActualDest) return self;
|
||||||
|
|
||||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
// Run validation checks FIRST before dealing with types
|
||||||
const DestT = ActualDest.ValueType;
|
|
||||||
const DestVec = @Vector(total, DestT);
|
|
||||||
|
|
||||||
// If ratio is 1, just handle type conversion
|
|
||||||
if (comptime ratio == 1.0) {
|
|
||||||
if (comptime T == DestT) return .{ .data = self.data };
|
|
||||||
return .{
|
|
||||||
.data = switch (comptime @typeInfo(DestT)) {
|
|
||||||
.float => @floatFromInt(self.data), // or @floatCast
|
|
||||||
.int => @intFromFloat(self.data), // or @intCast
|
|
||||||
else => unreachable,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (comptime !dims.eql(ActualDest.dims))
|
if (comptime !dims.eql(ActualDest.dims))
|
||||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||||
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
||||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||||
|
|
||||||
if (comptime Self == ActualDest) return self;
|
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||||
|
const DestT = ActualDest.ValueType;
|
||||||
|
const DestVec = @Vector(total, DestT);
|
||||||
|
|
||||||
|
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
||||||
|
if (comptime ratio == 1.0) {
|
||||||
|
const T_info = @typeInfo(T);
|
||||||
|
const Dest_info = @typeInfo(DestT);
|
||||||
|
|
||||||
|
return .{
|
||||||
|
.data = if (comptime T_info == .int and Dest_info == .int)
|
||||||
|
@as(DestVec, @intCast(self.data))
|
||||||
|
else if (comptime T_info == .float and Dest_info == .float)
|
||||||
|
@as(DestVec, @floatCast(self.data))
|
||||||
|
else if (comptime T_info == .int and Dest_info == .float)
|
||||||
|
@as(DestVec, @floatFromInt(self.data))
|
||||||
|
else if (comptime T_info == .float and Dest_info == .int)
|
||||||
|
@as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
|
||||||
|
else
|
||||||
|
unreachable,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (comptime T == DestT) {
|
if (comptime T == DestT) {
|
||||||
if (comptime @typeInfo(T) == .float)
|
if (comptime @typeInfo(T) == .float)
|
||||||
@ -466,33 +479,33 @@ pub fn Tensor(
|
|||||||
} else {
|
} else {
|
||||||
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
||||||
const half: T = comptime @divTrunc(div_val, 2);
|
const half: T = comptime @divTrunc(div_val, 2);
|
||||||
var result: DestVec = undefined;
|
|
||||||
inline for (0..total) |i| {
|
if (comptime @typeInfo(T).int.signedness == .unsigned) {
|
||||||
const val = self.data[i];
|
return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
|
||||||
result[i] = if (val >= 0)
|
} else {
|
||||||
@divTrunc(val + half, div_val)
|
// Vectorized branchless negative handling
|
||||||
else
|
const is_pos = self.data >= @as(Vec, @splat(0));
|
||||||
@divTrunc(val - half, div_val);
|
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
|
||||||
|
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
|
||||||
}
|
}
|
||||||
return .{ .data = result };
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var result: DestVec = undefined;
|
// Cross-type fully vectorized casting with scales
|
||||||
inline for (0..total) |i| {
|
const FVec = @Vector(total, f64);
|
||||||
const float_val: f64 = switch (comptime @typeInfo(T)) {
|
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
||||||
.float => @floatCast(self.data[i]),
|
.float => @floatCast(self.data),
|
||||||
.int => @floatFromInt(self.data[i]),
|
.int => @floatFromInt(self.data),
|
||||||
else => unreachable,
|
else => unreachable,
|
||||||
};
|
};
|
||||||
const scaled = float_val * ratio;
|
|
||||||
result[i] = switch (comptime @typeInfo(DestT)) {
|
const scaled = float_vec * @as(FVec, @splat(ratio));
|
||||||
.float => @floatCast(scaled),
|
|
||||||
.int => @intFromFloat(@round(scaled)),
|
return switch (comptime @typeInfo(DestT)) {
|
||||||
else => unreachable,
|
.float => .{ .data = @floatCast(scaled) },
|
||||||
};
|
.int => .{ .data = @intFromFloat(@round(scaled)) },
|
||||||
}
|
else => unreachable,
|
||||||
return .{ .data = result };
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const CmpResult = if (total == 1) bool else [total]bool;
|
const CmpResult = if (total == 1) bool else [total]bool;
|
||||||
@ -592,7 +605,7 @@ pub fn Tensor(
|
|||||||
const sa = shapeRemoveAxis(shape_, axis_a);
|
const sa = shapeRemoveAxis(shape_, axis_a);
|
||||||
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
||||||
const rs_raw = shapeCat(&sa, &sb);
|
const rs_raw = shapeCat(&sa, &sb);
|
||||||
const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
break :blk Tensor(
|
break :blk Tensor(
|
||||||
T,
|
T,
|
||||||
dims.add(OT.dims).argsOpt(),
|
dims.add(OT.dims).argsOpt(),
|
||||||
@ -606,7 +619,7 @@ pub fn Tensor(
|
|||||||
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
||||||
const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
|
const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
|
||||||
const rs_raw = comptime shapeCat(&sa, &sb);
|
const rs_raw = comptime shapeCat(&sa, &sb);
|
||||||
const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
|
|
||||||
const ResultType = Tensor(
|
const ResultType = Tensor(
|
||||||
T,
|
T,
|
||||||
@ -617,34 +630,85 @@ pub fn Tensor(
|
|||||||
|
|
||||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
||||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
||||||
|
|
||||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
||||||
|
|
||||||
|
// FAST PATH: Dot Product
|
||||||
|
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
||||||
|
if (comptime !isInt(T)) {
|
||||||
|
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
||||||
|
} else {
|
||||||
|
// For integers, we do a vectorized saturating multiply,
|
||||||
|
// then convert to an array to do a saturating sum
|
||||||
|
const mul_arr: [total]T = a_data *| b_data;
|
||||||
|
var acc: T = 0;
|
||||||
|
for (mul_arr) |val| acc +|= val;
|
||||||
|
return .{ .data = @splat(acc) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
|
||||||
|
const a_arr: [total]T = a_data;
|
||||||
|
const b_arr: [OT.total]T = b_data;
|
||||||
|
|
||||||
|
// FAST PATH: 2D Matrix Multiplication
|
||||||
|
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
|
||||||
|
const rows = shape_[0];
|
||||||
|
const cols = OT.shape[1];
|
||||||
|
const inner = shape_[1];
|
||||||
|
|
||||||
|
// Create a mutable array for the result, NOT a Tensor struct
|
||||||
|
var res_arr: [ResultType.total]T = undefined;
|
||||||
|
|
||||||
|
for (0..rows) |i| {
|
||||||
|
for (0..cols) |j| {
|
||||||
|
var acc: T = 0;
|
||||||
|
for (0..inner) |id| {
|
||||||
|
const a_flat = i * _strides[0] + id * _strides[1];
|
||||||
|
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
|
||||||
|
|
||||||
|
// Use a_arr and b_arr here
|
||||||
|
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
|
}
|
||||||
|
// Write to the array
|
||||||
|
res_arr[i * cols + j] = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Return the initialized Tensor struct
|
||||||
|
return .{ .data = res_arr };
|
||||||
|
}
|
||||||
|
|
||||||
|
// FALLBACK PATH
|
||||||
const rs_raw_strides = comptime shapeStrides(&rs_raw);
|
const rs_raw_strides = comptime shapeStrides(&rs_raw);
|
||||||
|
|
||||||
var result: ResultType = .{ .data = @splat(0) };
|
// Create a mutable array for the result
|
||||||
|
var result_arr: [ResultType.total]T = undefined;
|
||||||
|
|
||||||
inline for (0..ResultType.total) |res_flat| {
|
for (0..ResultType.total) |res_flat| {
|
||||||
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
||||||
|
|
||||||
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*;
|
var a_free: [sa.len]usize = undefined;
|
||||||
const b_free: [sb.len]usize = comptime res_coords[sa.len..].*;
|
for (0..sa.len) |i| a_free[i] = res_coords[i];
|
||||||
|
var b_free: [sb.len]usize = undefined;
|
||||||
|
for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];
|
||||||
|
|
||||||
var acc: T = 0;
|
var acc: T = 0;
|
||||||
inline for (0..k) |ki| {
|
for (0..k) |ki| {
|
||||||
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free);
|
const a_coords = insertAxis(rank, axis_a, ki, &a_free);
|
||||||
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free);
|
const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
|
||||||
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
|
const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
|
||||||
const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
||||||
|
|
||||||
if (comptime isInt(T))
|
// Use a_arr and b_arr here
|
||||||
acc +|= a_data[a_flat] *| b_data[b_flat]
|
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
else
|
|
||||||
acc += a_data[a_flat] * b_data[b_flat];
|
|
||||||
}
|
}
|
||||||
result.data[res_flat] = acc;
|
// Write to the array
|
||||||
|
result_arr[res_flat] = acc;
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
|
// Return the initialized Tensor struct
|
||||||
|
return .{ .data = result_arr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
||||||
@ -724,7 +788,8 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
try writer.writeAll("(");
|
try writer.writeAll("(");
|
||||||
inline for (0..total) |i| {
|
const max_to_print = 6;
|
||||||
|
inline for (0..@min(total, max_to_print)) |i| {
|
||||||
if (i > 0) try writer.writeAll(", ");
|
if (i > 0) try writer.writeAll(", ");
|
||||||
switch (@typeInfo(T)) {
|
switch (@typeInfo(T)) {
|
||||||
.float, .comptime_float => try writer.printFloat(self.data[i], options),
|
.float, .comptime_float => try writer.printFloat(self.data[i], options),
|
||||||
@ -736,6 +801,8 @@ pub fn Tensor(
|
|||||||
}),
|
}),
|
||||||
else => unreachable,
|
else => unreachable,
|
||||||
}
|
}
|
||||||
|
if (comptime i == max_to_print - 1 and total != max_to_print - 1)
|
||||||
|
try writer.writeAll(", ...");
|
||||||
}
|
}
|
||||||
try writer.writeAll(")");
|
try writer.writeAll(")");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void {
|
|||||||
|
|
||||||
try bench_Scalar(&stdout_writer.interface);
|
try bench_Scalar(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try bench_vsNative(&stdout_writer.interface);
|
try bench_vsNative(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
try bench_Vector(&stdout_writer.interface);
|
try bench_Vector(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user