Removed inline for TensorAlloc fn

This commit is contained in:
adrien 2026-05-25 01:54:47 +02:00
parent 0ef19e18de
commit 8028cf41a5

View File

@ -67,7 +67,7 @@ pub fn Tensor(
/// Element-wise add. Dimensions must match; scales resolve to finer. /// Element-wise add. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast). /// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor( pub fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
@ -100,7 +100,7 @@ pub fn Tensor(
/// Element-wise sub. Dimensions must match; scales resolve to finer. /// Element-wise sub. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast). /// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn sub(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor( pub fn sub(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
T, T,
dims.argsOpt(), dims.argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
@ -133,7 +133,7 @@ pub fn Tensor(
/// Element-wise multiply. Dimension exponents summed. /// Element-wise multiply. Dimension exponents summed.
/// Shape {1} RHS is automatically broadcast across all elements. /// Shape {1} RHS is automatically broadcast across all elements.
pub inline fn mul(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor( pub fn mul(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
T, T,
dims.add(@TypeOf(rhs).dims).argsOpt(), dims.add(@TypeOf(rhs).dims).argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
@ -167,7 +167,7 @@ pub fn Tensor(
/// Element-wise divide. Dimension exponents subtracted. /// Element-wise divide. Dimension exponents subtracted.
/// Shape {1} RHS is automatically broadcast across all elements. /// Shape {1} RHS is automatically broadcast across all elements.
pub inline fn div(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor( pub fn div(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
T, T,
dims.sub(@TypeOf(rhs).dims).argsOpt(), dims.sub(@TypeOf(rhs).dims).argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
@ -200,7 +200,7 @@ pub fn Tensor(
} }
/// Absolute value of every element. /// Absolute value of every element.
pub inline fn abs(self: *const Self, alloc: Allocator) !Self { pub fn abs(self: *const Self, alloc: Allocator) !Self {
const result_vec = @as(Vec, @bitCast(@abs(self.data.*))); const result_vec = @as(Vec, @bitCast(@abs(self.data.*)));
const vec_ptr = try alloc.create(@TypeOf(result_vec)); const vec_ptr = try alloc.create(@TypeOf(result_vec));
@ -210,7 +210,7 @@ pub fn Tensor(
} }
/// Raise every element to a comptime integer exponent. /// Raise every element to a comptime integer exponent.
pub inline fn pow(self: *const Self, alloc: Allocator, comptime exp: comptime_int) !Tensor( pub fn pow(self: *const Self, alloc: Allocator, comptime exp: comptime_int) !Tensor(
T, T,
dims.scale(exp).argsOpt(), dims.scale(exp).argsOpt(),
scales.argsOpt(), scales.argsOpt(),
@ -245,7 +245,7 @@ pub fn Tensor(
} }
/// Square root of every element. All dimension exponents must be even. /// Square root of every element. All dimension exponents must be even.
pub inline fn sqrt(self: *const Self, alloc: Allocator) !Tensor( pub fn sqrt(self: *const Self, alloc: Allocator) !Tensor(
T, T,
dims.div(2).argsOpt(), dims.div(2).argsOpt(),
scales.argsOpt(), scales.argsOpt(),
@ -279,7 +279,7 @@ pub fn Tensor(
} }
/// Negate every element. /// Negate every element.
pub inline fn negate(self: *const Self, alloc: Allocator) !Self { pub fn negate(self: *const Self, alloc: Allocator) !Self {
const result_vec = -self.data.*; const result_vec = -self.data.*;
const vec_ptr = try alloc.create(@TypeOf(result_vec)); const vec_ptr = try alloc.create(@TypeOf(result_vec));
@ -292,7 +292,7 @@ pub fn Tensor(
/// Dimension mismatch compile error. /// Dimension mismatch compile error.
/// Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern). /// Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
/// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. /// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
pub inline fn to( pub fn to(
self: *const Self, self: *const Self,
alloc: Allocator, alloc: Allocator,
comptime Dest: type, comptime Dest: type,
@ -384,12 +384,12 @@ pub fn Tensor(
const CmpResult = if (total == 1) bool else [total]bool; const CmpResult = if (total == 1) bool else [total]bool;
inline fn cmpResult(v: @Vector(total, bool)) CmpResult { fn cmpResult(v: @Vector(total, bool)) CmpResult {
return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v); return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
} }
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: *const Self, alloc: Allocator, rhs: anytype) !struct { fn resolveScalePair(self: *const Self, alloc: Allocator, rhs: anytype) !struct {
l: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape), l: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
r: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape), r: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
@ -408,7 +408,7 @@ pub fn Tensor(
return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) }; return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) };
} }
pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in eq."); @compileError("Dimension mismatch in eq.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -416,7 +416,7 @@ pub fn Tensor(
return cmpResult(p.l.data.* == p.r.data.*); return cmpResult(p.l.data.* == p.r.data.*);
} }
pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in ne."); @compileError("Dimension mismatch in ne.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -424,7 +424,7 @@ pub fn Tensor(
return cmpResult(p.l.data.* != p.r.data.*); return cmpResult(p.l.data.* != p.r.data.*);
} }
pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gt."); @compileError("Dimension mismatch in gt.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -432,7 +432,7 @@ pub fn Tensor(
return cmpResult(p.l.data.* > p.r.data.*); return cmpResult(p.l.data.* > p.r.data.*);
} }
pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gte."); @compileError("Dimension mismatch in gte.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -440,7 +440,7 @@ pub fn Tensor(
return cmpResult(p.l.data.* >= p.r.data.*); return cmpResult(p.l.data.* >= p.r.data.*);
} }
pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lt."); @compileError("Dimension mismatch in lt.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -448,7 +448,7 @@ pub fn Tensor(
return cmpResult(p.l.data.* < p.r.data.*); return cmpResult(p.l.data.* < p.r.data.*);
} }
pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { pub fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lte."); @compileError("Dimension mismatch in lte.");
const p = try resolveScalePair(self, alloc, rhs); const p = try resolveScalePair(self, alloc, rhs);
@ -457,7 +457,7 @@ pub fn Tensor(
} }
/// True iff every element is equal after scale resolution. /// True iff every element is equal after scale resolution.
pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool { pub fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
if (comptime !dims.eql(@TypeOf(other).dims)) if (comptime !dims.eql(@TypeOf(other).dims))
@compileError("Dimension mismatch in eqAll."); @compileError("Dimension mismatch in eqAll.");
const p = try resolveScalePair(self, alloc, other); const p = try resolveScalePair(self, alloc, other);
@ -466,7 +466,7 @@ pub fn Tensor(
} }
/// True iff any element differs after scale resolution. /// True iff any element differs after scale resolution.
pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool { pub fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
return !(try self.eqAll(alloc, other)); return !(try self.eqAll(alloc, other));
} }
}; };