Changed Tensor to use *const

This commit is contained in:
adrien 2026-04-28 13:10:14 +02:00
parent bb6dd59b9a
commit 26ff02c50f
2 changed files with 160 additions and 153 deletions

View File

@ -139,8 +139,8 @@ inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
/// Take the anyvalue coming from operation and if it is a Tensor, return it. /// Take the anyvalue coming from operation and if it is a Tensor, return it.
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). /// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, @TypeOf(r.*)) {
const Rhs = @TypeOf(r); const Rhs = @TypeOf(r.*);
if (comptime isTensor(Rhs)) return r; if (comptime isTensor(Rhs)) return r;
const scalar: T = switch (@typeInfo(Rhs)) { const scalar: T = switch (@typeInfo(Rhs)) {
.comptime_int => switch (comptime @typeInfo(T)) { .comptime_int => switch (comptime @typeInfo(T)) {
@ -252,7 +252,7 @@ pub fn Tensor(
inline fn RhsT(comptime Rhs: type) type { inline fn RhsT(comptime Rhs: type) type {
return RhsTensorType(T, Rhs); return RhsTensorType(T, Rhs);
} }
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { inline fn rhs(r: anytype) *const RhsT(@TypeOf(r.*)) {
return toRhsTensor(T, r); return toRhsTensor(T, r);
} }
@ -265,20 +265,20 @@ pub fn Tensor(
/// Element-wise add. Dimensions must match; scales resolve to finer. /// Element-wise add. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast). /// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn add(self: Self, r: anytype) Tensor( pub inline fn add(self: *const Self, r: anytype) *const Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_t = rhs(r); const rhs_t = rhs(r);
const RhsType = @TypeOf(r); const RhsType = @TypeOf(rhs_t.*);
if (comptime !dims.eql(RhsType.dims)) if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; return &.{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
@ -287,25 +287,25 @@ pub fn Tensor(
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn);
}; };
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; return &.{ .data = if (comptime isInt(T)) l +| rr else l + rr };
} }
/// Element-wise sub. Dimensions must match; scales resolve to finer. /// Element-wise sub. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast). /// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn sub(self: Self, r: anytype) Tensor( pub inline fn sub(self: *const Self, r: anytype) *const Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_t = rhs(r); const rhs_t = rhs(r);
const RhsType = @TypeOf(rhs_t); const RhsType = @TypeOf(rhs_t.*);
if (comptime !dims.eql(RhsType.dims)) if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; return &.{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
@ -314,19 +314,19 @@ pub fn Tensor(
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn);
}; };
return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; return &.{ .data = if (comptime isInt(T)) l -| rr else l - rr };
} }
/// Element-wise multiply. Dimension exponents summed. /// Element-wise multiply. Dimension exponents summed.
/// Shape {1} RHS is automatically broadcast across all elements. /// Shape {1} RHS is automatically broadcast across all elements.
pub inline fn mul(self: Self, r: anytype) Tensor( pub inline fn mul(self: *const Self, r: anytype) *const Tensor(
T, T,
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q.*);
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
@ -334,20 +334,20 @@ pub fn Tensor(
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
const rr: Vec = broadcastToVec(RhsNorm, rr_base); const rr: Vec = broadcastToVec(RhsNorm, rr_base.*);
return .{ .data = if (comptime isInt(T)) l *| rr else l * rr }; return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr };
} }
/// Element-wise divide. Dimension exponents subtracted. /// Element-wise divide. Dimension exponents subtracted.
/// Shape {1} RHS is automatically broadcast across all elements. /// Shape {1} RHS is automatically broadcast across all elements.
pub inline fn div(self: Self, r: anytype) Tensor( pub inline fn div(self: *const Self, r: anytype) *const Tensor(
T, T,
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q.*);
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
@ -355,17 +355,17 @@ pub fn Tensor(
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
const rr: Vec = broadcastToVec(RhsNorm, rr_base); const rr: Vec = broadcastToVec(RhsNorm, rr_base.*);
if (comptime isInt(T)) { if (comptime isInt(T)) {
return .{ .data = @divTrunc(l, rr) }; return &.{ .data = @divTrunc(l, rr) };
} else { } else {
return .{ .data = l / rr }; return &.{ .data = l / rr };
} }
} }
/// Absolute value of every element. /// Absolute value of every element.
pub inline fn abs(self: Self) Self { pub inline fn abs(self: *const Self) *const Self {
return .{ .data = @bitCast(@abs(self.data)) }; return &.{ .data = @bitCast(@abs(self.data)) };
} }
/// Raise every element to a comptime integer exponent. /// Raise every element to a comptime integer exponent.
@ -375,7 +375,7 @@ pub fn Tensor(
scales.argsOpt(), scales.argsOpt(),
shape_, shape_,
) { ) {
if (comptime exp == 0) return .{ .data = @splat(1) }; if (comptime exp == 0) return &.{ .data = @splat(1) };
if (comptime exp == 1) return self; if (comptime exp == 1) return self;
var base = self.data; var base = self.data;
@ -395,7 +395,7 @@ pub fn Tensor(
if (comptime !isInt(T) and exp < 0) { if (comptime !isInt(T) and exp < 0) {
result = @as(Vec, @splat(1)) / result; result = @as(Vec, @splat(1)) / result;
} }
return .{ .data = result }; return &.{ .data = result };
} }
/// Square root of every element. All dimension exponents must be even. /// Square root of every element. All dimension exponents must be even.
@ -408,7 +408,7 @@ pub fn Tensor(
if (comptime !dims.isSquare()) if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
if (comptime @typeInfo(T) == .float) { if (comptime @typeInfo(T) == .float) {
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! return &.{ .data = @sqrt(self.data) }; // Float is natively vectorized!
} else { } else {
const arr: [total]T = self.data; // Add this! const arr: [total]T = self.data; // Add this!
var res_arr: [total]T = undefined; var res_arr: [total]T = undefined;
@ -417,13 +417,13 @@ pub fn Tensor(
const v = arr[i]; const v = arr[i];
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
} }
return .{ .data = res_arr }; return &.{ .data = res_arr };
} }
} }
/// Negate every element. /// Negate every element.
pub inline fn negate(self: Self) Self { pub inline fn negate(self: Self) Self {
return .{ .data = -self.data }; return &.{ .data = -self.data };
} }
/// Convert to a compatible Tensor type. /// Convert to a compatible Tensor type.
@ -431,9 +431,9 @@ pub fn Tensor(
/// Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern). /// Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern).
/// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. /// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
pub inline fn to( pub inline fn to(
self: Self, self: *const Self,
comptime Dest: type, comptime Dest: type,
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { ) *const Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
if (comptime Self == ActualDest) return self; if (comptime Self == ActualDest) return self;
@ -449,14 +449,14 @@ pub fn Tensor(
const DestVec = @Vector(total, DestT); const DestVec = @Vector(total, DestT);
if (comptime ratio == 1.0 and T == DestT) if (comptime ratio == 1.0 and T == DestT)
return .{ .data = self.data }; return &.{ .data = self.data };
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
if (comptime ratio == 1.0) { if (comptime ratio == 1.0) {
const T_info = @typeInfo(T); const T_info = @typeInfo(T);
const Dest_info = @typeInfo(DestT); const Dest_info = @typeInfo(DestT);
return .{ return &.{
.data = if (comptime T_info == .int and Dest_info == .int) .data = if (comptime T_info == .int and Dest_info == .int)
@as(DestVec, @intCast(self.data)) @as(DestVec, @intCast(self.data))
else if (comptime T_info == .float and Dest_info == .float) else if (comptime T_info == .float and Dest_info == .float)
@ -472,22 +472,22 @@ pub fn Tensor(
if (comptime T == DestT) { if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) if (comptime @typeInfo(T) == .float)
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; return &.{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
if (comptime ratio >= 1.0) { if (comptime ratio >= 1.0) {
const mult: T = comptime @intFromFloat(@round(ratio)); const mult: T = comptime @intFromFloat(@round(ratio));
return .{ .data = self.data *| @as(Vec, @splat(mult)) }; return &.{ .data = self.data *| @as(Vec, @splat(mult)) };
} else { } else {
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
const half: T = comptime @divTrunc(div_val, 2); const half: T = comptime @divTrunc(div_val, 2);
if (comptime @typeInfo(T).int.signedness == .unsigned) { if (comptime @typeInfo(T).int.signedness == .unsigned) {
return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) }; return &.{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
} else { } else {
// Vectorized branchless negative handling // Vectorized branchless negative handling
const is_pos = self.data >= @as(Vec, @splat(0)); const is_pos = self.data >= @as(Vec, @splat(0));
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half))); const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) }; return &.{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
} }
} }
} }
@ -503,8 +503,8 @@ pub fn Tensor(
const scaled = float_vec * @as(FVec, @splat(ratio)); const scaled = float_vec * @as(FVec, @splat(ratio));
return switch (comptime @typeInfo(DestT)) { return switch (comptime @typeInfo(DestT)) {
.float => .{ .data = @floatCast(scaled) }, .float => &.{ .data = @floatCast(scaled) },
.int => .{ .data = @intFromFloat(@round(scaled)) }, .int => &.{ .data = @intFromFloat(@round(scaled)) },
else => unreachable, else => unreachable,
}; };
} }
@ -516,8 +516,8 @@ pub fn Tensor(
} }
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(rhs_q.*);
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
@ -526,74 +526,74 @@ pub fn Tensor(
const rr: Vec = blk: { const rr: Vec = blk: {
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
break :blk broadcastToVec(RhsNorm, rn); break :blk broadcastToVec(RhsNorm, rn.*);
}; };
return .{ .l = l, .r = rr }; return .{ .l = l, .r = rr };
} }
pub inline fn eq(self: Self, r: anytype) CmpResult { pub inline fn eq(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in ne."); @compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l == p.r); return cmpResult(p.l == p.r);
} }
pub inline fn ne(self: Self, r: anytype) CmpResult { pub inline fn ne(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in ne."); @compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l != p.r); return cmpResult(p.l != p.r);
} }
pub inline fn gt(self: Self, r: anytype) CmpResult { pub inline fn gt(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in gt."); @compileError("Dimension mismatch in gt.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l > p.r); return cmpResult(p.l > p.r);
} }
pub inline fn gte(self: Self, r: anytype) CmpResult { pub inline fn gte(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in gte."); @compileError("Dimension mismatch in gte.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l >= p.r); return cmpResult(p.l >= p.r);
} }
pub inline fn lt(self: Self, r: anytype) CmpResult { pub inline fn lt(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in lt."); @compileError("Dimension mismatch in lt.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l < p.r); return cmpResult(p.l < p.r);
} }
pub inline fn lte(self: Self, r: anytype) CmpResult { pub inline fn lte(self: *const Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q.*).dims))
@compileError("Dimension mismatch in lte."); @compileError("Dimension mismatch in lte.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l <= p.r); return cmpResult(p.l <= p.r);
} }
/// True iff every element is equal after scale resolution. /// True iff every element is equal after scale resolution.
pub inline fn eqAll(self: Self, other: anytype) bool { pub inline fn eqAll(self: *const Self, other: anytype) bool {
if (comptime !dims.eql(@TypeOf(other).dims)) if (comptime !dims.eql(@TypeOf(other.*).dims))
@compileError("Dimension mismatch in eqAll."); @compileError("Dimension mismatch in eqAll.");
const p = resolveScalePair(self, other); const p = resolveScalePair(self, other);
return @reduce(.And, p.l == p.r); return @reduce(.And, p.l == p.r);
} }
/// True iff any element differs after scale resolution. /// True iff any element differs after scale resolution.
pub inline fn neAll(self: Self, other: anytype) bool { pub inline fn neAll(self: *const Self, other: anytype) bool {
return !self.eqAll(other); return !self.eqAll(other);
} }
pub inline fn contract( pub inline fn contract(
self: Self, self: *const Self,
other: anytype, other: anytype,
comptime axis_a: usize, comptime axis_a: usize,
comptime axis_b: usize, comptime axis_b: usize,
@ -607,7 +607,7 @@ pub fn Tensor(
const sb = shapeRemoveAxis(OT.shape, axis_b); const sb = shapeRemoveAxis(OT.shape, axis_b);
const rs_raw = shapeCat(&sa, &sb); const rs_raw = shapeCat(&sa, &sb);
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
break :blk Tensor( break :blk *const Tensor(
T, T,
dims.add(OT.dims).argsOpt(), dims.add(OT.dims).argsOpt(),
finerScales(Self, OT).argsOpt(), finerScales(Self, OT).argsOpt(),
@ -645,7 +645,7 @@ pub fn Tensor(
const mul_arr: [total]T = a_data *| b_data; const mul_arr: [total]T = a_data *| b_data;
var acc: T = 0; var acc: T = 0;
for (mul_arr) |val| acc +|= val; for (mul_arr) |val| acc +|= val;
return .{ .data = @splat(acc) }; return &.{ .data = @splat(acc) };
} }
} }
@ -677,7 +677,7 @@ pub fn Tensor(
} }
} }
// Return the initialized Tensor struct // Return the initialized Tensor struct
return .{ .data = res_arr }; return &.{ .data = res_arr };
} }
// FALLBACK PATH // FALLBACK PATH
@ -709,12 +709,12 @@ pub fn Tensor(
} }
// Return the initialized Tensor struct // Return the initialized Tensor struct
return .{ .data = result_arr }; return &.{ .data = result_arr };
} }
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3. /// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
/// Result dimensions are the sum of input dimensions. /// Result dimensions are the sum of input dimensions.
pub inline fn cross(self: Self, other: anytype) Tensor( pub inline fn cross(self: *const Self, other: anytype) *const Tensor(
T, T,
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(), dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
@ -743,16 +743,16 @@ pub fn Tensor(
res[2] = (l[0] * r[1]) - (l[1] * r[0]); res[2] = (l[0] * r[1]) - (l[1] * r[0]);
} }
return .{ .data = res }; return &.{ .data = res };
} }
/// Sum of squared elements. Cheaper than length(); use for ordering. /// Sum of squared elements. Cheaper than length(); use for ordering.
pub inline fn lengthSqr(self: Self) T { pub inline fn lengthSqr(self: *const Self) T {
return @reduce(.Add, self.data * self.data); return @reduce(.Add, self.data * self.data);
} }
/// Euclidean length (L2 norm). /// Euclidean length (L2 norm).
pub inline fn length(self: Self) T { pub inline fn length(self: *const Self) T {
const sq = self.lengthSqr(); const sq = self.lengthSqr();
if (comptime @typeInfo(T) == .int) { if (comptime @typeInfo(T) == .int) {
const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits);
@ -762,17 +762,17 @@ pub fn Tensor(
} }
/// Product of all elements. Result has shape {1}; dimension exponent * total. /// Product of all elements. Result has shape {1}; dimension exponent * total.
pub inline fn product(self: Self) Tensor( pub inline fn product(self: *const Self) *const Tensor(
T, T,
dims.scale(@as(comptime_int, total)).argsOpt(), dims.scale(@as(comptime_int, total)).argsOpt(),
scales.argsOpt(), scales.argsOpt(),
&.{1}, &.{1},
) { ) {
return .{ .data = .{@reduce(.Mul, self.data)} }; return &.{ .data = .{@reduce(.Mul, self.data)} };
} }
pub fn formatNumber( pub fn formatNumber(
self: Self, self: *const Self,
writer: *std.Io.Writer, writer: *std.Io.Writer,
options: std.fmt.Number, options: std.fmt.Number,
) !void { ) !void {
@ -1393,3 +1393,12 @@ test "Tensor strides_arr correctness" {
try std.testing.expectEqual(4, T3.strides_arr[1]); try std.testing.expectEqual(4, T3.strides_arr[1]);
try std.testing.expectEqual(1, T3.strides_arr[2]); try std.testing.expectEqual(1, T3.strides_arr[2]);
} }
test "Big Tensor" {
const Tens = Tensor(f32, .{}, .{}, &.{1_000_000});
const t1 = Tens.splat(2);
const t2 = Tens.splat(3);
const t3 = t1.add(t2);
try std.testing.expectApproxEqAbs(5, t3.data[0], 0.0001);
}

View File

@ -26,11 +26,11 @@ pub fn main(init: std.process.Init) !void {
try bench_vsNative(&stdout_writer.interface); try bench_vsNative(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
// try bench_crossTypeVsNative(&stdout_writer.interface); // try bench_crossTypeVsNative(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_Vector(&stdout_writer.interface); // try bench_Vector(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_HighDimTensor(&stdout_writer.interface); // try bench_HighDimTensor(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
} }
fn getTime() Io.Timestamp { fn getTime() Io.Timestamp {
@ -180,8 +180,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
} }
}.f; }.f;
const Types = .{ f64, i64, i128, f32, f64 }; const Types = .{ i32, i64, i128, f32, f64 };
const TNames = .{ "f64", "i64", "i128", "f32", "f64" }; const TNames = .{ "i32", "i64", "i128", "f32", "f64" };
// Expanded Ops to match bench_Scalar // Expanded Ops to match bench_Scalar
const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" }; const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" };
@ -203,86 +203,84 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
const M = Tensor(T, .{}, .{}, &.{1}); const M = Tensor(T, .{}, .{}, &.{1});
std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| {
for (0..SAMPLES) |_| { // --- 1. Benchmark Native ---
// --- 1. Benchmark Native --- const n_start = getTime();
const n_start = getTime(); const a = getValT(T, 10);
const a = getValT(T, 10); const b = getValT(T, 2);
const b = getValT(T, 2); for (0..ITERS) |_| {
for (0..ITERS) |_| { // Native logic branch
// Native logic branch _ = if (comptime std.mem.eql(u8, op_name, "add"))
_ = if (comptime std.mem.eql(u8, op_name, "add")) if (comptime @typeInfo(T) == .int) a +| b else a + b
if (comptime @typeInfo(T) == .int) a +| b else a + b else if (comptime std.mem.eql(u8, op_name, "sub"))
else if (comptime std.mem.eql(u8, op_name, "sub")) if (comptime @typeInfo(T) == .int) a -| b else a - b
if (comptime @typeInfo(T) == .int) a -| b else a - b else if (comptime std.mem.eql(u8, op_name, "mul"))
else if (comptime std.mem.eql(u8, op_name, "mul")) if (comptime @typeInfo(T) == .int) a *| b else a * b
if (comptime @typeInfo(T) == .int) a *| b else a * b else if (comptime std.mem.eql(u8, op_name, "div"))
else if (comptime std.mem.eql(u8, op_name, "div")) if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b else if (comptime std.mem.eql(u8, op_name, "abs"))
else if (comptime std.mem.eql(u8, op_name, "abs")) if (comptime @typeInfo(T) == .int) @abs(a) else @as(T, @abs(a))
if (comptime @typeInfo(T) == .int) @abs(a) else @as(T, @abs(a)) else if (comptime std.mem.eql(u8, op_name, "eq"))
else if (comptime std.mem.eql(u8, op_name, "eq")) a == b
a == b else if (comptime std.mem.eql(u8, op_name, "gt"))
else if (comptime std.mem.eql(u8, op_name, "gt")) a > b
a > b else
else unreachable;
unreachable;
}
const n_end = getTime();
native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds()));
const v_start = getTime();
const va = getValT(T, 10);
const vb = getValT(T, 2);
for (0..ITERS) |_| {
// Native logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
if (comptime @typeInfo(T) == .int) va +| vb else va + vb
else if (comptime std.mem.eql(u8, op_name, "sub"))
if (comptime @typeInfo(T) == .int) va -| vb else va - vb
else if (comptime std.mem.eql(u8, op_name, "mul"))
if (comptime @typeInfo(T) == .int) va *| vb else va * vb
else if (comptime std.mem.eql(u8, op_name, "div"))
if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb
else if (comptime std.mem.eql(u8, op_name, "abs"))
if (comptime @typeInfo(T) == .int) @abs(va) else @as(T, @abs(va))
else if (comptime std.mem.eql(u8, op_name, "eq"))
va == vb
else if (comptime std.mem.eql(u8, op_name, "gt"))
va > vb
else
unreachable;
}
const v_end = getTime();
vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds()));
// --- 2. Benchmark Scalar ---
const q_start = getTime();
const qa = M.splat(getValT(T, 10));
const qb = M.splat(getValT(T, 2));
for (0..ITERS) |_| {
// Scalar logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
qa.add(qb)
else if (comptime std.mem.eql(u8, op_name, "sub"))
qa.sub(qb)
else if (comptime std.mem.eql(u8, op_name, "mul"))
qa.mul(qb)
else if (comptime std.mem.eql(u8, op_name, "div"))
qa.div(qb)
else if (comptime std.mem.eql(u8, op_name, "abs"))
qa.abs()
else if (comptime std.mem.eql(u8, op_name, "eq"))
qa.eq(qb)
else if (comptime std.mem.eql(u8, op_name, "gt"))
qa.gt(qb)
else
unreachable;
}
const q_end = getTime();
tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds()));
} }
}); const n_end = getTime();
native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds()));
const v_start = getTime();
const va = @Vector(1, T){getValT(T, 10)};
const vb = @Vector(1, T){getValT(T, 2)};
for (0..ITERS) |_| {
// Native logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
if (comptime @typeInfo(T) == .int) va +| vb else va + vb
else if (comptime std.mem.eql(u8, op_name, "sub"))
if (comptime @typeInfo(T) == .int) va -| vb else va - vb
else if (comptime std.mem.eql(u8, op_name, "mul"))
if (comptime @typeInfo(T) == .int) va *| vb else va * vb
else if (comptime std.mem.eql(u8, op_name, "div"))
if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb
else if (comptime std.mem.eql(u8, op_name, "abs"))
if (comptime @typeInfo(T) == .int) @as(T, @intCast(@abs(va[0]))) else @abs(va)
else if (comptime std.mem.eql(u8, op_name, "eq"))
va == vb
else if (comptime std.mem.eql(u8, op_name, "gt"))
va > vb
else
unreachable;
}
const v_end = getTime();
vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds()));
// --- 2. Benchmark Scalar ---
const q_start = getTime();
const qa = M.splat(getValT(T, 10));
const qb = M.splat(getValT(T, 2));
for (0..ITERS) |_| {
// Scalar logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
&qa.add(&qb)
else if (comptime std.mem.eql(u8, op_name, "sub"))
&qa.sub(&qb)
else if (comptime std.mem.eql(u8, op_name, "mul"))
&qa.mul(&qb)
else if (comptime std.mem.eql(u8, op_name, "div"))
&qa.div(&qb)
else if (comptime std.mem.eql(u8, op_name, "abs"))
&qa.abs()
else if (comptime std.mem.eql(u8, op_name, "eq"))
&qa.eq(&qb)
else if (comptime std.mem.eql(u8, op_name, "gt"))
&qa.gt(&qb)
else
unreachable;
}
const q_end = getTime();
tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds()));
}
const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));