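Replaced every `inline for (0..total)` loop with either a vectorized builtin or a plain (non-inline) `for` loop.
This prevents the binary from growing huge for a Tensor that holds a large number of Scalars, since each unrolled loop emitted one copy of its body per element.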
This commit is contained in:
parent
cd954b379b
commit
44aaa8a8b2
@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int),
|
||||
/// Unspecified dimensions default to 0.
|
||||
pub fn init(comptime init_val: ArgOpts) Self {
|
||||
var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) };
|
||||
inline for (std.meta.fields(@TypeOf(init_val))) |f|
|
||||
for (std.meta.fields(@TypeOf(init_val))) |f|
|
||||
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
||||
return s;
|
||||
}
|
||||
|
||||
pub fn initFill(comptime val: comptime_int) Self {
|
||||
return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
|
||||
comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) };
|
||||
}
|
||||
|
||||
pub fn get(comptime self: Self, comptime key: Dimension) comptime_int {
|
||||
return self.data.get(key);
|
||||
comptime return self.data.get(key);
|
||||
}
|
||||
|
||||
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void {
|
||||
@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void
|
||||
|
||||
pub fn argsOpt(self: Self) ArgOpts {
|
||||
var args: ArgOpts = undefined;
|
||||
inline for (std.enums.values(Dimension)) |d|
|
||||
for (std.enums.values(Dimension)) |d|
|
||||
@field(args, @tagName(d)) = self.get(d);
|
||||
return args;
|
||||
comptime return args;
|
||||
}
|
||||
|
||||
/// Add exponents component-wise. Used internally by `mul`.
|
||||
pub fn add(comptime a: Self, comptime b: Self) Self {
|
||||
var result = Self.initFill(0);
|
||||
inline for (std.enums.values(Dimension)) |d|
|
||||
for (std.enums.values(Dimension)) |d|
|
||||
result.set(d, a.get(d) + b.get(d));
|
||||
return result;
|
||||
comptime return result;
|
||||
}
|
||||
|
||||
/// Subtract exponents component-wise. Used internally by `div`.
|
||||
pub fn sub(comptime a: Self, comptime b: Self) Self {
|
||||
var result = Self.initFill(0);
|
||||
inline for (std.enums.values(Dimension)) |d|
|
||||
for (std.enums.values(Dimension)) |d|
|
||||
result.set(d, a.get(d) - b.get(d));
|
||||
return result;
|
||||
comptime return result;
|
||||
}
|
||||
|
||||
/// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar.
|
||||
pub fn scale(comptime a: Self, comptime exp: comptime_int) Self {
|
||||
var result = Self.initFill(0);
|
||||
inline for (std.enums.values(Dimension)) |d|
|
||||
for (std.enums.values(Dimension)) |d|
|
||||
result.set(d, a.get(d) * exp);
|
||||
return result;
|
||||
comptime return result;
|
||||
}
|
||||
|
||||
pub fn div(comptime a: Self, comptime exp: comptime_int) Self {
|
||||
var result = Self.initFill(0);
|
||||
inline for (std.enums.values(Dimension)) |d|
|
||||
result.set(d, a.get(d) / exp);
|
||||
return result;
|
||||
comptime return result;
|
||||
}
|
||||
|
||||
/// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`.
|
||||
|
||||
@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) {
|
||||
}
|
||||
|
||||
pub inline fn getFactor(self: @This()) comptime_float {
|
||||
return comptime switch (self) {
|
||||
comptime return switch (self) {
|
||||
// Standard SI Exponents
|
||||
inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))),
|
||||
|
||||
@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) {
|
||||
inline .lb => 453.59237,
|
||||
inline .st => 6350.29318,
|
||||
|
||||
inline else => @floatFromInt(@intFromEnum(self)),
|
||||
else => @floatFromInt(@intFromEnum(self)),
|
||||
};
|
||||
}
|
||||
};
|
||||
@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self {
|
||||
else
|
||||
s.data.set(@field(Dimension, f.name), @field(init_val, f.name));
|
||||
}
|
||||
return s;
|
||||
return comptime s;
|
||||
}
|
||||
|
||||
pub fn initFill(comptime val: UnitScale) Self {
|
||||
return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
|
||||
comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) };
|
||||
}
|
||||
|
||||
pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
|
||||
@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale {
|
||||
}
|
||||
|
||||
pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void {
|
||||
comptime self.data.set(key, val);
|
||||
self.data.set(key, val);
|
||||
}
|
||||
|
||||
pub fn argsOpt(self: Self) ArgOpts {
|
||||
@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float
|
||||
factor /= base;
|
||||
}
|
||||
}
|
||||
return comptime factor;
|
||||
comptime return factor;
|
||||
}
|
||||
|
||||
269
src/Tensor.zig
269
src/Tensor.zig
@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
|
||||
// Comptime utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn shapeTotal(comptime shape: []const usize) usize {
|
||||
var t: usize = 1;
|
||||
pub fn shapeTotal(comptime shape: []const comptime_int) usize {
|
||||
var t: comptime_int = 1;
|
||||
for (shape) |s| t *= s;
|
||||
return t;
|
||||
}
|
||||
|
||||
/// Check if two shapes are strictly identical.
|
||||
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
||||
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
|
||||
if (a.len != b.len) return false;
|
||||
for (a, 0..) |v, i|
|
||||
if (v != b[i]) return false;
|
||||
@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||
/// e.g. shape {3, 4} → strides {4, 1}
|
||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||
pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize {
|
||||
var st: [shape.len]usize = undefined;
|
||||
pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
|
||||
var st: [shape.len]comptime_int = undefined;
|
||||
if (shape.len == 0) return st;
|
||||
st[shape.len - 1] = 1;
|
||||
if (shape.len > 1) {
|
||||
var i: usize = shape.len - 1;
|
||||
var i: comptime_int = shape.len - 1;
|
||||
while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i];
|
||||
}
|
||||
return st;
|
||||
}
|
||||
|
||||
/// Return a copy of `shape` with the element at `axis` removed.
|
||||
pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize {
|
||||
var out: [shape.len - 1]usize = undefined;
|
||||
var j: usize = 0;
|
||||
pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
|
||||
var out: [shape.len - 1]comptime_int = undefined;
|
||||
var j: comptime_int = 0;
|
||||
for (shape, 0..) |v, i| {
|
||||
if (i != axis) {
|
||||
out[j] = v;
|
||||
@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha
|
||||
}
|
||||
|
||||
/// Concatenate two compile-time slices.
|
||||
pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize {
|
||||
var out: [a.len + b.len]usize = undefined;
|
||||
pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||
var out: [a.len + b.len]comptime_int = undefined;
|
||||
for (a, 0..) |v, i| out[i] = v;
|
||||
for (b, 0..) |v, i| out[a.len + i] = v;
|
||||
return out;
|
||||
@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b
|
||||
/// Decode a flat row-major index into N-D coordinates.
|
||||
/// Called only in comptime contexts (all arguments are comptime).
|
||||
pub fn decodeFlatCoords(
|
||||
comptime flat: usize,
|
||||
comptime n: usize,
|
||||
comptime strd: [n]usize,
|
||||
comptime flat: comptime_int,
|
||||
comptime n: comptime_int,
|
||||
comptime strd: [n]comptime_int,
|
||||
) [n]usize {
|
||||
var coords: [n]usize = undefined;
|
||||
var coords: [n]comptime_int = undefined;
|
||||
var tmp = flat;
|
||||
for (0..n) |i| {
|
||||
coords[i] = if (strd[i] == 0) 0 else tmp / strd[i];
|
||||
@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||
const s1: Scales = T1.scales;
|
||||
const s2: Scales = T2.scales;
|
||||
comptime var out = Scales.initFill(.none);
|
||||
inline for (std.enums.values(Dimension)) |dim| {
|
||||
for (std.enums.values(Dimension)) |dim| {
|
||||
const scale1 = comptime s1.get(dim);
|
||||
const scale2 = comptime s2.get(dim);
|
||||
out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0)
|
||||
@ -201,7 +201,7 @@ pub fn Tensor(
|
||||
comptime T: type,
|
||||
comptime d_opt: Dimensions.ArgOpts,
|
||||
comptime s_opt: Scales.ArgOpts,
|
||||
comptime shape_: []const usize,
|
||||
comptime shape_: []const comptime_int,
|
||||
) type {
|
||||
comptime {
|
||||
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
||||
@ -223,10 +223,10 @@ pub fn Tensor(
|
||||
pub const ValueType: type = T;
|
||||
pub const dims: Dimensions = Dimensions.init(d_opt);
|
||||
pub const scales: Scales = Scales.init(s_opt);
|
||||
pub const shape: []const usize = shape_;
|
||||
pub const rank: usize = shape_.len;
|
||||
pub const total: usize = _total;
|
||||
pub const strides_arr: [shape_.len]usize = _strides;
|
||||
pub const shape: []const comptime_int = shape_;
|
||||
pub const rank: comptime_int = shape_.len;
|
||||
pub const total: comptime_int = _total;
|
||||
pub const strides_arr: [shape_.len]comptime_int = _strides;
|
||||
pub const ISTENSOR = true;
|
||||
|
||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||
@ -359,9 +359,7 @@ pub fn Tensor(
|
||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||
if (comptime isInt(T)) {
|
||||
var result: Vec = undefined;
|
||||
inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]);
|
||||
return .{ .data = result };
|
||||
return .{ .data = @divTrunc(l, rr) };
|
||||
} else {
|
||||
return .{ .data = l / rr };
|
||||
}
|
||||
@ -379,19 +377,27 @@ pub fn Tensor(
|
||||
scales.argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
if (comptime isInt(T)) {
|
||||
var result: Vec = undefined;
|
||||
inline for (0..total) |i|
|
||||
result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T);
|
||||
return .{ .data = result };
|
||||
} else {
|
||||
const abs_exp = comptime @abs(exp);
|
||||
var result: Vec = @splat(1);
|
||||
comptime var i = 0;
|
||||
inline while (i < abs_exp) : (i += 1) result *= self.data;
|
||||
if (comptime exp < 0) result = @as(Vec, @splat(1)) / result;
|
||||
return .{ .data = result };
|
||||
if (comptime exp == 0) return .{ .data = @splat(1) };
|
||||
if (comptime exp == 1) return self;
|
||||
|
||||
var base = self.data;
|
||||
var result: Vec = @splat(1);
|
||||
comptime var e = @abs(exp);
|
||||
|
||||
// O(log n) exponentiation by squaring, applied to the whole vector at once
|
||||
inline while (e > 0) {
|
||||
if (e % 2 == 1) {
|
||||
result = if (comptime isInt(T)) result *| base else result * base;
|
||||
}
|
||||
e /= 2;
|
||||
if (e > 0) {
|
||||
base = if (comptime isInt(T)) base *| base else base * base;
|
||||
}
|
||||
}
|
||||
if (comptime !isInt(T) and exp < 0) {
|
||||
result = @as(Vec, @splat(1)) / result;
|
||||
}
|
||||
return .{ .data = result };
|
||||
}
|
||||
|
||||
/// Square root of every element. All dimension exponents must be even.
|
||||
@ -404,15 +410,16 @@ pub fn Tensor(
|
||||
if (comptime !dims.isSquare())
|
||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||
if (comptime @typeInfo(T) == .float) {
|
||||
return .{ .data = @sqrt(self.data) };
|
||||
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
||||
} else {
|
||||
var result: Vec = undefined;
|
||||
const arr: [total]T = self.data; // Add this!
|
||||
var res_arr: [total]T = undefined;
|
||||
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
|
||||
inline for (0..total) |i| {
|
||||
const v = self.data[i];
|
||||
result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
||||
for (0..total) |i| {
|
||||
const v = arr[i];
|
||||
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
||||
}
|
||||
return .{ .data = result };
|
||||
return .{ .data = res_arr };
|
||||
}
|
||||
}
|
||||
|
||||
@ -433,28 +440,34 @@ pub fn Tensor(
|
||||
|
||||
if (comptime Self == ActualDest) return self;
|
||||
|
||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||
const DestT = ActualDest.ValueType;
|
||||
const DestVec = @Vector(total, DestT);
|
||||
|
||||
// If ratio is 1, just handle type conversion
|
||||
if (comptime ratio == 1.0) {
|
||||
if (comptime T == DestT) return .{ .data = self.data };
|
||||
return .{
|
||||
.data = switch (comptime @typeInfo(DestT)) {
|
||||
.float => @floatFromInt(self.data), // or @floatCast
|
||||
.int => @intFromFloat(self.data), // or @intCast
|
||||
else => unreachable,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Run validation checks FIRST before dealing with types
|
||||
if (comptime !dims.eql(ActualDest.dims))
|
||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||
|
||||
if (comptime Self == ActualDest) return self;
|
||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||
const DestT = ActualDest.ValueType;
|
||||
const DestVec = @Vector(total, DestT);
|
||||
|
||||
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
||||
if (comptime ratio == 1.0) {
|
||||
const T_info = @typeInfo(T);
|
||||
const Dest_info = @typeInfo(DestT);
|
||||
|
||||
return .{
|
||||
.data = if (comptime T_info == .int and Dest_info == .int)
|
||||
@as(DestVec, @intCast(self.data))
|
||||
else if (comptime T_info == .float and Dest_info == .float)
|
||||
@as(DestVec, @floatCast(self.data))
|
||||
else if (comptime T_info == .int and Dest_info == .float)
|
||||
@as(DestVec, @floatFromInt(self.data))
|
||||
else if (comptime T_info == .float and Dest_info == .int)
|
||||
@as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
|
||||
else
|
||||
unreachable,
|
||||
};
|
||||
}
|
||||
|
||||
if (comptime T == DestT) {
|
||||
if (comptime @typeInfo(T) == .float)
|
||||
@ -466,33 +479,33 @@ pub fn Tensor(
|
||||
} else {
|
||||
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
||||
const half: T = comptime @divTrunc(div_val, 2);
|
||||
var result: DestVec = undefined;
|
||||
inline for (0..total) |i| {
|
||||
const val = self.data[i];
|
||||
result[i] = if (val >= 0)
|
||||
@divTrunc(val + half, div_val)
|
||||
else
|
||||
@divTrunc(val - half, div_val);
|
||||
|
||||
if (comptime @typeInfo(T).int.signedness == .unsigned) {
|
||||
return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
|
||||
} else {
|
||||
// Vectorized branchless negative handling
|
||||
const is_pos = self.data >= @as(Vec, @splat(0));
|
||||
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
|
||||
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
|
||||
}
|
||||
return .{ .data = result };
|
||||
}
|
||||
}
|
||||
|
||||
var result: DestVec = undefined;
|
||||
inline for (0..total) |i| {
|
||||
const float_val: f64 = switch (comptime @typeInfo(T)) {
|
||||
.float => @floatCast(self.data[i]),
|
||||
.int => @floatFromInt(self.data[i]),
|
||||
else => unreachable,
|
||||
};
|
||||
const scaled = float_val * ratio;
|
||||
result[i] = switch (comptime @typeInfo(DestT)) {
|
||||
.float => @floatCast(scaled),
|
||||
.int => @intFromFloat(@round(scaled)),
|
||||
else => unreachable,
|
||||
};
|
||||
}
|
||||
return .{ .data = result };
|
||||
// Cross-type fully vectorized casting with scales
|
||||
const FVec = @Vector(total, f64);
|
||||
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
||||
.float => @floatCast(self.data),
|
||||
.int => @floatFromInt(self.data),
|
||||
else => unreachable,
|
||||
};
|
||||
|
||||
const scaled = float_vec * @as(FVec, @splat(ratio));
|
||||
|
||||
return switch (comptime @typeInfo(DestT)) {
|
||||
.float => .{ .data = @floatCast(scaled) },
|
||||
.int => .{ .data = @intFromFloat(@round(scaled)) },
|
||||
else => unreachable,
|
||||
};
|
||||
}
|
||||
|
||||
const CmpResult = if (total == 1) bool else [total]bool;
|
||||
@ -592,7 +605,7 @@ pub fn Tensor(
|
||||
const sa = shapeRemoveAxis(shape_, axis_a);
|
||||
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
||||
const rs_raw = shapeCat(&sa, &sb);
|
||||
const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||
break :blk Tensor(
|
||||
T,
|
||||
dims.add(OT.dims).argsOpt(),
|
||||
@ -606,7 +619,7 @@ pub fn Tensor(
|
||||
const sa = comptime shapeRemoveAxis(shape_, axis_a);
|
||||
const sb = comptime shapeRemoveAxis(OT.shape, axis_b);
|
||||
const rs_raw = comptime shapeCat(&sa, &sb);
|
||||
const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||
|
||||
const ResultType = Tensor(
|
||||
T,
|
||||
@ -617,34 +630,85 @@ pub fn Tensor(
|
||||
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
||||
|
||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
||||
|
||||
// FAST PATH: Dot Product
|
||||
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
||||
if (comptime !isInt(T)) {
|
||||
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
||||
} else {
|
||||
// For integers, we do a vectorized saturating multiply,
|
||||
// then convert to an array to do a saturating sum
|
||||
const mul_arr: [total]T = a_data *| b_data;
|
||||
var acc: T = 0;
|
||||
for (mul_arr) |val| acc +|= val;
|
||||
return .{ .data = @splat(acc) };
|
||||
}
|
||||
}
|
||||
|
||||
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
|
||||
const a_arr: [total]T = a_data;
|
||||
const b_arr: [OT.total]T = b_data;
|
||||
|
||||
// FAST PATH: 2D Matrix Multiplication
|
||||
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
|
||||
const rows = shape_[0];
|
||||
const cols = OT.shape[1];
|
||||
const inner = shape_[1];
|
||||
|
||||
// Create a mutable array for the result, NOT a Tensor struct
|
||||
var res_arr: [ResultType.total]T = undefined;
|
||||
|
||||
for (0..rows) |i| {
|
||||
for (0..cols) |j| {
|
||||
var acc: T = 0;
|
||||
for (0..inner) |id| {
|
||||
const a_flat = i * _strides[0] + id * _strides[1];
|
||||
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
|
||||
|
||||
// Use a_arr and b_arr here
|
||||
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||
}
|
||||
// Write to the array
|
||||
res_arr[i * cols + j] = acc;
|
||||
}
|
||||
}
|
||||
// Return the initialized Tensor struct
|
||||
return .{ .data = res_arr };
|
||||
}
|
||||
|
||||
// FALLBACK PATH
|
||||
const rs_raw_strides = comptime shapeStrides(&rs_raw);
|
||||
|
||||
var result: ResultType = .{ .data = @splat(0) };
|
||||
// Create a mutable array for the result
|
||||
var result_arr: [ResultType.total]T = undefined;
|
||||
|
||||
inline for (0..ResultType.total) |res_flat| {
|
||||
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
||||
for (0..ResultType.total) |res_flat| {
|
||||
const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
||||
|
||||
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*;
|
||||
const b_free: [sb.len]usize = comptime res_coords[sa.len..].*;
|
||||
var a_free: [sa.len]usize = undefined;
|
||||
for (0..sa.len) |i| a_free[i] = res_coords[i];
|
||||
var b_free: [sb.len]usize = undefined;
|
||||
for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i];
|
||||
|
||||
var acc: T = 0;
|
||||
inline for (0..k) |ki| {
|
||||
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free);
|
||||
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free);
|
||||
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
|
||||
const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
||||
for (0..k) |ki| {
|
||||
const a_coords = insertAxis(rank, axis_a, ki, &a_free);
|
||||
const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free);
|
||||
const a_flat = encodeFlatCoords(&a_coords, rank, _strides);
|
||||
const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
||||
|
||||
if (comptime isInt(T))
|
||||
acc +|= a_data[a_flat] *| b_data[b_flat]
|
||||
else
|
||||
acc += a_data[a_flat] * b_data[b_flat];
|
||||
// Use a_arr and b_arr here
|
||||
if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||
}
|
||||
result.data[res_flat] = acc;
|
||||
// Write to the array
|
||||
result_arr[res_flat] = acc;
|
||||
}
|
||||
return result;
|
||||
|
||||
// Return the initialized Tensor struct
|
||||
return .{ .data = result_arr };
|
||||
}
|
||||
|
||||
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
||||
@ -724,7 +788,8 @@ pub fn Tensor(
|
||||
}
|
||||
} else {
|
||||
try writer.writeAll("(");
|
||||
inline for (0..total) |i| {
|
||||
const max_to_print = 6;
|
||||
inline for (0..@min(total, max_to_print)) |i| {
|
||||
if (i > 0) try writer.writeAll(", ");
|
||||
switch (@typeInfo(T)) {
|
||||
.float, .comptime_float => try writer.printFloat(self.data[i], options),
|
||||
@ -736,6 +801,8 @@ pub fn Tensor(
|
||||
}),
|
||||
else => unreachable,
|
||||
}
|
||||
if (comptime i == max_to_print - 1 and total != max_to_print - 1)
|
||||
try writer.writeAll(", ...");
|
||||
}
|
||||
try writer.writeAll(")");
|
||||
}
|
||||
|
||||
@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void {
|
||||
|
||||
try bench_Scalar(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
// try bench_vsNative(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
try bench_vsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_Vector(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user