tmp
This commit is contained in:
parent
8816a65518
commit
f0029449f0
@ -139,8 +139,9 @@ inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
|||||||
|
|
||||||
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
||||||
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
||||||
inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
||||||
const Rhs = if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r);
|
const is_ptr = @typeInfo(@TypeOf(r)) == .pointer;
|
||||||
|
const Rhs = if (is_ptr) @TypeOf(r.*) else @TypeOf(r);
|
||||||
if (comptime isTensor(Rhs)) return r;
|
if (comptime isTensor(Rhs)) return r;
|
||||||
const scalar: T = switch (@typeInfo(Rhs)) {
|
const scalar: T = switch (@typeInfo(Rhs)) {
|
||||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||||
@ -161,7 +162,7 @@ inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, if (
|
|||||||
},
|
},
|
||||||
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
||||||
};
|
};
|
||||||
return &Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
return Tensor(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||||
@ -253,7 +254,7 @@ pub fn Tensor(
|
|||||||
return RhsTensorType(T, Rhs);
|
return RhsTensorType(T, Rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn rhs(r: anytype) *const RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
inline fn rhs(r: anytype) RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)) {
|
||||||
return toRhsTensor(T, r);
|
return toRhsTensor(T, r);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,7 +267,7 @@ pub fn Tensor(
|
|||||||
|
|
||||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn add(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn add(self: *const Self, r: anytype) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
@ -279,7 +280,7 @@ pub fn Tensor(
|
|||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||||
return &.{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
|
return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
@ -288,12 +289,12 @@ pub fn Tensor(
|
|||||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
return &.{ .data = if (comptime isInt(T)) l +| rr else l + rr };
|
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn sub(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn sub(self: *const Self, r: anytype) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.argsOpt(),
|
dims.argsOpt(),
|
||||||
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
@ -306,7 +307,7 @@ pub fn Tensor(
|
|||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||||
return &.{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
|
return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
@ -315,12 +316,12 @@ pub fn Tensor(
|
|||||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||||
break :blk broadcastToVec(RhsNorm, rn);
|
break :blk broadcastToVec(RhsNorm, rn);
|
||||||
};
|
};
|
||||||
return &.{ .data = if (comptime isInt(T)) l -| rr else l - rr };
|
return .{ .data = if (comptime isInt(T)) l -| rr else l - rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise multiply. Dimension exponents summed.
|
/// Element-wise multiply. Dimension exponents summed.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn mul(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn mul(self: *const Self, r: anytype) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.add(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
dims.add(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
@ -336,12 +337,12 @@ pub fn Tensor(
|
|||||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
return .{ .data = if (comptime isInt(T)) l *| rr else l * rr };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise divide. Dimension exponents subtracted.
|
/// Element-wise divide. Dimension exponents subtracted.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn div(self: *const Self, r: anytype) *const Tensor(
|
pub inline fn div(self: *const Self, r: anytype) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.sub(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
dims.sub(RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r)).dims).argsOpt(),
|
||||||
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
finerScales(Self, RhsT(if (@typeInfo(@TypeOf(r)) == .pointer) @TypeOf(r.*) else @TypeOf(r))).argsOpt(),
|
||||||
@ -358,19 +359,19 @@ pub fn Tensor(
|
|||||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||||
if (comptime isInt(T)) {
|
if (comptime isInt(T)) {
|
||||||
return &.{ .data = @divTrunc(l, rr) };
|
return .{ .data = @divTrunc(l, rr) };
|
||||||
} else {
|
} else {
|
||||||
return &.{ .data = l / rr };
|
return .{ .data = l / rr };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Absolute value of every element.
|
/// Absolute value of every element.
|
||||||
pub inline fn abs(self: *const Self) *const Self {
|
pub inline fn abs(self: *const Self) Self {
|
||||||
return &.{ .data = @bitCast(@abs(self.data)) };
|
return .{ .data = @bitCast(@abs(self.data)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Raise every element to a comptime integer exponent.
|
/// Raise every element to a comptime integer exponent.
|
||||||
pub inline fn pow(self: Self, comptime exp: comptime_int) *const Tensor(
|
pub inline fn pow(self: Self, comptime exp: comptime_int) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.scale(exp).argsOpt(),
|
dims.scale(exp).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -396,11 +397,11 @@ pub fn Tensor(
|
|||||||
if (comptime !isInt(T) and exp < 0) {
|
if (comptime !isInt(T) and exp < 0) {
|
||||||
result = @as(Vec, @splat(1)) / result;
|
result = @as(Vec, @splat(1)) / result;
|
||||||
}
|
}
|
||||||
return &.{ .data = result };
|
return .{ .data = result };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Square root of every element. All dimension exponents must be even.
|
/// Square root of every element. All dimension exponents must be even.
|
||||||
pub inline fn sqrt(self: Self) *const Tensor(
|
pub inline fn sqrt(self: Self) Tensor(
|
||||||
T,
|
T,
|
||||||
dims.div(2).argsOpt(),
|
dims.div(2).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
@ -409,7 +410,7 @@ pub fn Tensor(
|
|||||||
if (comptime !dims.isSquare())
|
if (comptime !dims.isSquare())
|
||||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||||
if (comptime @typeInfo(T) == .float) {
|
if (comptime @typeInfo(T) == .float) {
|
||||||
return &.{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
return .{ .data = @sqrt(self.data) }; // Float is natively vectorized!
|
||||||
} else {
|
} else {
|
||||||
const arr: [total]T = self.data; // Add this!
|
const arr: [total]T = self.data; // Add this!
|
||||||
var res_arr: [total]T = undefined;
|
var res_arr: [total]T = undefined;
|
||||||
@ -418,13 +419,13 @@ pub fn Tensor(
|
|||||||
const v = arr[i];
|
const v = arr[i];
|
||||||
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v)))));
|
||||||
}
|
}
|
||||||
return &.{ .data = res_arr };
|
return .{ .data = res_arr };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Negate every element.
|
/// Negate every element.
|
||||||
pub inline fn negate(self: Self) *const Self {
|
pub inline fn negate(self: Self) Self {
|
||||||
return &.{ .data = -self.data };
|
return .{ .data = -self.data };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert to a compatible Tensor type.
|
/// Convert to a compatible Tensor type.
|
||||||
@ -434,7 +435,7 @@ pub fn Tensor(
|
|||||||
pub inline fn to(
|
pub inline fn to(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
comptime Dest: type,
|
comptime Dest: type,
|
||||||
) *const Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
||||||
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
||||||
|
|
||||||
if (comptime Self == ActualDest) return self;
|
if (comptime Self == ActualDest) return self;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user