Removed inline for TensorAlloc fn
This commit is contained in:
parent
0ef19e18de
commit
8028cf41a5
@ -67,7 +67,7 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
pub fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
@ -100,7 +100,7 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn sub(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
pub fn sub(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
@ -133,7 +133,7 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise multiply. Dimension exponents summed.
|
/// Element-wise multiply. Dimension exponents summed.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn mul(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
pub fn mul(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
@ -167,7 +167,7 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise divide. Dimension exponents subtracted.
|
/// Element-wise divide. Dimension exponents subtracted.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn div(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
pub fn div(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.sub(@TypeOf(rhs).dims).argsOpt(),
|
dims.sub(@TypeOf(rhs).dims).argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
@ -200,7 +200,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Absolute value of every element.
|
/// Absolute value of every element.
|
||||||
pub inline fn abs(self: *const Self, alloc: Allocator) !Self {
|
pub fn abs(self: *const Self, alloc: Allocator) !Self {
|
||||||
const result_vec = @as(Vec, @bitCast(@abs(self.data.*)));
|
const result_vec = @as(Vec, @bitCast(@abs(self.data.*)));
|
||||||
|
|
||||||
const vec_ptr = try alloc.create(@TypeOf(result_vec));
|
const vec_ptr = try alloc.create(@TypeOf(result_vec));
|
||||||
@ -210,7 +210,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Raise every element to a comptime integer exponent.
|
/// Raise every element to a comptime integer exponent.
|
||||||
pub inline fn pow(self: *const Self, alloc: Allocator, comptime exp: comptime_int) !Tensor(
|
pub fn pow(self: *const Self, alloc: Allocator, comptime exp: comptime_int) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.scale(exp).argsOpt(),
|
dims.scale(exp).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -245,7 +245,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
pub inline fn sqrt(self: *const Self, alloc: Allocator) !Tensor(
|
pub fn sqrt(self: *const Self, alloc: Allocator) !Tensor(
|
||||||
T,
|
T,
|
||||||
dims.div(2).argsOpt(),
|
dims.div(2).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -279,7 +279,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Negate every element.
|
/// Negate every element.
|
||||||
pub inline fn negate(self: *const Self, alloc: Allocator) !Self {
|
pub fn negate(self: *const Self, alloc: Allocator) !Self {
|
||||||
const result_vec = -self.data.*;
|
const result_vec = -self.data.*;
|
||||||
|
|
||||||
const vec_ptr = try alloc.create(@TypeOf(result_vec));
|
const vec_ptr = try alloc.create(@TypeOf(result_vec));
|
||||||
@ -292,7 +292,7 @@ pub fn Tensor(
|
|||||||
/// • Dimension mismatch → compile error.
|
/// • Dimension mismatch → compile error.
|
||||||
/// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
|
/// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
|
||||||
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
||||||
pub inline fn to(
|
pub fn to(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
alloc: Allocator,
|
alloc: Allocator,
|
||||||
comptime Dest: type,
|
comptime Dest: type,
|
||||||
@ -384,12 +384,12 @@ pub fn Tensor(
|
|||||||
|
|
||||||
const CmpResult = if (total == 1) bool else [total]bool;
|
const CmpResult = if (total == 1) bool else [total]bool;
|
||||||
|
|
||||||
inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
|
fn cmpResult(v: @Vector(total, bool)) CmpResult {
|
||||||
return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
|
return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
inline fn resolveScalePair(self: *const Self, alloc: Allocator, rhs: anytype) !struct {
|
fn resolveScalePair(self: *const Self, alloc: Allocator, rhs: anytype) !struct {
|
||||||
l: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
|
l: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
|
||||||
r: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
|
r: Tensor(T, dims.argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), shape),
|
||||||
|
|
||||||
@ -408,7 +408,7 @@ pub fn Tensor(
|
|||||||
return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) };
|
return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in eq.");
|
@compileError("Dimension mismatch in eq.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -416,7 +416,7 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* == p.r.data.*);
|
return cmpResult(p.l.data.* == p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in ne.");
|
@compileError("Dimension mismatch in ne.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -424,7 +424,7 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* != p.r.data.*);
|
return cmpResult(p.l.data.* != p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in gt.");
|
@compileError("Dimension mismatch in gt.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -432,7 +432,7 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* > p.r.data.*);
|
return cmpResult(p.l.data.* > p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in gte.");
|
@compileError("Dimension mismatch in gte.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -440,7 +440,7 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* >= p.r.data.*);
|
return cmpResult(p.l.data.* >= p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in lt.");
|
@compileError("Dimension mismatch in lt.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -448,7 +448,7 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* < p.r.data.*);
|
return cmpResult(p.l.data.* < p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in lte.");
|
@compileError("Dimension mismatch in lte.");
|
||||||
const p = try resolveScalePair(self, alloc, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
@ -457,7 +457,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// True iff every element is equal after scale resolution.
|
/// True iff every element is equal after scale resolution.
|
||||||
pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
pub fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||||
if (comptime !dims.eql(@TypeOf(other).dims))
|
if (comptime !dims.eql(@TypeOf(other).dims))
|
||||||
@compileError("Dimension mismatch in eqAll.");
|
@compileError("Dimension mismatch in eqAll.");
|
||||||
const p = try resolveScalePair(self, alloc, other);
|
const p = try resolveScalePair(self, alloc, other);
|
||||||
@ -466,7 +466,7 @@ pub fn Tensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// True iff any element differs after scale resolution.
|
/// True iff any element differs after scale resolution.
|
||||||
pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
pub fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||||
return !(try self.eqAll(alloc, other));
|
return !(try self.eqAll(alloc, other));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user